• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 // Copyright 2018 The ChromiumOS Authors
2 // Use of this source code is governed by a BSD-style license that can be
3 // found in the LICENSE file.
4 
5 use std::cmp::Ordering;
6 use std::convert::TryFrom;
7 use std::ffi::OsString;
8 use std::fs::remove_file;
9 use std::io;
10 use std::mem;
11 use std::mem::size_of;
12 use std::net::SocketAddr;
13 use std::net::SocketAddrV4;
14 use std::net::SocketAddrV6;
15 use std::net::TcpListener;
16 use std::net::TcpStream;
17 use std::net::ToSocketAddrs;
18 use std::ops::Deref;
19 use std::os::unix::ffi::OsStrExt;
20 use std::os::unix::ffi::OsStringExt;
21 use std::os::unix::io::RawFd;
22 use std::path::Path;
23 use std::path::PathBuf;
24 use std::ptr::null_mut;
25 use std::time::Duration;
26 use std::time::Instant;
27 
28 use libc::c_int;
29 use libc::in6_addr;
30 use libc::in_addr;
31 use libc::recvfrom;
32 use libc::sa_family_t;
33 use libc::sockaddr;
34 use libc::sockaddr_in;
35 use libc::sockaddr_in6;
36 use libc::socklen_t;
37 use libc::AF_INET;
38 use libc::AF_INET6;
39 use libc::MSG_PEEK;
40 use libc::MSG_TRUNC;
41 use libc::SOCK_CLOEXEC;
42 use libc::SOCK_STREAM;
43 use log::warn;
44 use serde::Deserialize;
45 use serde::Serialize;
46 
47 use super::sock_ctrl_msg::ScmSocket;
48 use super::sock_ctrl_msg::SCM_SOCKET_MAX_FD_COUNT;
49 use super::Error;
50 use super::RawDescriptor;
51 use crate::descriptor::AsRawDescriptor;
52 use crate::descriptor::FromRawDescriptor;
53 use crate::SafeDescriptor;
54 
/// Distinguishes IP version 4 from IP version 6 so callers can handle both.
#[derive(Debug, Copy, Clone, PartialEq, Eq)]
pub enum InetVersion {
    /// IP version 4.
    V4,
    /// IP version 6.
    V6,
}

impl InetVersion {
    /// Returns the IP version that matches the variant of the given socket address.
    pub fn from_sockaddr(s: &SocketAddr) -> Self {
        if s.is_ipv4() {
            InetVersion::V4
        } else {
            InetVersion::V6
        }
    }
}
70 
71 impl From<InetVersion> for sa_family_t {
from(v: InetVersion) -> sa_family_t72     fn from(v: InetVersion) -> sa_family_t {
73         match v {
74             InetVersion::V4 => AF_INET as sa_family_t,
75             InetVersion::V6 => AF_INET6 as sa_family_t,
76         }
77     }
78 }
79 
sockaddrv4_to_lib_c(s: &SocketAddrV4) -> sockaddr_in80 fn sockaddrv4_to_lib_c(s: &SocketAddrV4) -> sockaddr_in {
81     sockaddr_in {
82         sin_family: AF_INET as sa_family_t,
83         sin_port: s.port().to_be(),
84         sin_addr: in_addr {
85             s_addr: u32::from_ne_bytes(s.ip().octets()),
86         },
87         sin_zero: [0; 8],
88     }
89 }
90 
sockaddrv6_to_lib_c(s: &SocketAddrV6) -> sockaddr_in691 fn sockaddrv6_to_lib_c(s: &SocketAddrV6) -> sockaddr_in6 {
92     sockaddr_in6 {
93         sin6_family: AF_INET6 as sa_family_t,
94         sin6_port: s.port().to_be(),
95         sin6_flowinfo: 0,
96         sin6_addr: in6_addr {
97             s6_addr: s.ip().octets(),
98         },
99         sin6_scope_id: 0,
100     }
101 }
102 
socket(domain: c_int, sock_type: c_int, protocol: c_int) -> io::Result<SafeDescriptor>103 fn socket(domain: c_int, sock_type: c_int, protocol: c_int) -> io::Result<SafeDescriptor> {
104     // Safe socket initialization since we handle the returned error.
105     match unsafe { libc::socket(domain, sock_type, protocol) } {
106         -1 => Err(io::Error::last_os_error()),
107         // Safe because we own the file descriptor.
108         fd => Ok(unsafe { SafeDescriptor::from_raw_descriptor(fd) }),
109     }
110 }
111 
socketpair( domain: c_int, sock_type: c_int, protocol: c_int, ) -> io::Result<(SafeDescriptor, SafeDescriptor)>112 fn socketpair(
113     domain: c_int,
114     sock_type: c_int,
115     protocol: c_int,
116 ) -> io::Result<(SafeDescriptor, SafeDescriptor)> {
117     let mut fds = [0, 0];
118     // Safe because we give enough space to store all the fds and we check the return value.
119     match unsafe { libc::socketpair(domain, sock_type, protocol, fds.as_mut_ptr()) } {
120         -1 => Err(io::Error::last_os_error()),
121         // Safe because we own the file descriptors.
122         _ => Ok(unsafe {
123             (
124                 SafeDescriptor::from_raw_descriptor(fds[0]),
125                 SafeDescriptor::from_raw_descriptor(fds[1]),
126             )
127         }),
128     }
129 }
130 
131 /// A TCP socket.
132 ///
133 /// Do not use this class unless you need to change socket options or query the
134 /// state of the socket prior to calling listen or connect. Instead use either TcpStream or
135 /// TcpListener.
136 #[derive(Debug)]
137 pub struct TcpSocket {
138     inet_version: InetVersion,
139     descriptor: SafeDescriptor,
140 }
141 
142 impl TcpSocket {
new(inet_version: InetVersion) -> io::Result<Self>143     pub fn new(inet_version: InetVersion) -> io::Result<Self> {
144         Ok(TcpSocket {
145             inet_version,
146             descriptor: socket(
147                 Into::<sa_family_t>::into(inet_version) as c_int,
148                 SOCK_STREAM | SOCK_CLOEXEC,
149                 0,
150             )?,
151         })
152     }
153 
bind<A: ToSocketAddrs>(&mut self, addr: A) -> io::Result<()>154     pub fn bind<A: ToSocketAddrs>(&mut self, addr: A) -> io::Result<()> {
155         let sockaddr = addr
156             .to_socket_addrs()
157             .map_err(|_| io::Error::from_raw_os_error(libc::EINVAL))?
158             .next()
159             .unwrap();
160 
161         let ret = match sockaddr {
162             SocketAddr::V4(a) => {
163                 let sin = sockaddrv4_to_lib_c(&a);
164                 // Safe because this doesn't modify any memory and we check the return value.
165                 unsafe {
166                     libc::bind(
167                         self.as_raw_descriptor(),
168                         &sin as *const sockaddr_in as *const sockaddr,
169                         size_of::<sockaddr_in>() as socklen_t,
170                     )
171                 }
172             }
173             SocketAddr::V6(a) => {
174                 let sin6 = sockaddrv6_to_lib_c(&a);
175                 // Safe because this doesn't modify any memory and we check the return value.
176                 unsafe {
177                     libc::bind(
178                         self.as_raw_descriptor(),
179                         &sin6 as *const sockaddr_in6 as *const sockaddr,
180                         size_of::<sockaddr_in6>() as socklen_t,
181                     )
182                 }
183             }
184         };
185         if ret < 0 {
186             let bind_err = io::Error::last_os_error();
187             Err(bind_err)
188         } else {
189             Ok(())
190         }
191     }
192 
connect<A: ToSocketAddrs>(self, addr: A) -> io::Result<TcpStream>193     pub fn connect<A: ToSocketAddrs>(self, addr: A) -> io::Result<TcpStream> {
194         let sockaddr = addr
195             .to_socket_addrs()
196             .map_err(|_| io::Error::from_raw_os_error(libc::EINVAL))?
197             .next()
198             .unwrap();
199 
200         let ret = match sockaddr {
201             SocketAddr::V4(a) => {
202                 let sin = sockaddrv4_to_lib_c(&a);
203                 // Safe because this doesn't modify any memory and we check the return value.
204                 unsafe {
205                     libc::connect(
206                         self.as_raw_descriptor(),
207                         &sin as *const sockaddr_in as *const sockaddr,
208                         size_of::<sockaddr_in>() as socklen_t,
209                     )
210                 }
211             }
212             SocketAddr::V6(a) => {
213                 let sin6 = sockaddrv6_to_lib_c(&a);
214                 // Safe because this doesn't modify any memory and we check the return value.
215                 unsafe {
216                     libc::connect(
217                         self.as_raw_descriptor(),
218                         &sin6 as *const sockaddr_in6 as *const sockaddr,
219                         size_of::<sockaddr_in>() as socklen_t,
220                     )
221                 }
222             }
223         };
224 
225         if ret < 0 {
226             let connect_err = io::Error::last_os_error();
227             Err(connect_err)
228         } else {
229             Ok(TcpStream::from(self.descriptor))
230         }
231     }
232 
listen(self) -> io::Result<TcpListener>233     pub fn listen(self) -> io::Result<TcpListener> {
234         // Safe because this doesn't modify any memory and we check the return value.
235         let ret = unsafe { libc::listen(self.as_raw_descriptor(), 1) };
236         if ret < 0 {
237             let listen_err = io::Error::last_os_error();
238             Err(listen_err)
239         } else {
240             Ok(TcpListener::from(self.descriptor))
241         }
242     }
243 
244     /// Returns the port that this socket is bound to. This can only succeed after bind is called.
local_port(&self) -> io::Result<u16>245     pub fn local_port(&self) -> io::Result<u16> {
246         match self.inet_version {
247             InetVersion::V4 => {
248                 let mut sin = sockaddr_in {
249                     sin_family: 0,
250                     sin_port: 0,
251                     sin_addr: in_addr { s_addr: 0 },
252                     sin_zero: [0; 8],
253                 };
254 
255                 // Safe because we give a valid pointer for addrlen and check the length.
256                 let mut addrlen = size_of::<sockaddr_in>() as socklen_t;
257                 let ret = unsafe {
258                     // Get the socket address that was actually bound.
259                     libc::getsockname(
260                         self.as_raw_descriptor(),
261                         &mut sin as *mut sockaddr_in as *mut sockaddr,
262                         &mut addrlen as *mut socklen_t,
263                     )
264                 };
265                 if ret < 0 {
266                     let getsockname_err = io::Error::last_os_error();
267                     Err(getsockname_err)
268                 } else {
269                     // If this doesn't match, it's not safe to get the port out of the sockaddr.
270                     assert_eq!(addrlen as usize, size_of::<sockaddr_in>());
271 
272                     Ok(u16::from_be(sin.sin_port))
273                 }
274             }
275             InetVersion::V6 => {
276                 let mut sin6 = sockaddr_in6 {
277                     sin6_family: 0,
278                     sin6_port: 0,
279                     sin6_flowinfo: 0,
280                     sin6_addr: in6_addr { s6_addr: [0; 16] },
281                     sin6_scope_id: 0,
282                 };
283 
284                 // Safe because we give a valid pointer for addrlen and check the length.
285                 let mut addrlen = size_of::<sockaddr_in6>() as socklen_t;
286                 let ret = unsafe {
287                     // Get the socket address that was actually bound.
288                     libc::getsockname(
289                         self.as_raw_descriptor(),
290                         &mut sin6 as *mut sockaddr_in6 as *mut sockaddr,
291                         &mut addrlen as *mut socklen_t,
292                     )
293                 };
294                 if ret < 0 {
295                     let getsockname_err = io::Error::last_os_error();
296                     Err(getsockname_err)
297                 } else {
298                     // If this doesn't match, it's not safe to get the port out of the sockaddr.
299                     assert_eq!(addrlen as usize, size_of::<sockaddr_in>());
300 
301                     Ok(u16::from_be(sin6.sin6_port))
302                 }
303             }
304         }
305     }
306 }
307 
308 impl AsRawDescriptor for TcpSocket {
as_raw_descriptor(&self) -> RawDescriptor309     fn as_raw_descriptor(&self) -> RawDescriptor {
310         self.descriptor.as_raw_descriptor()
311     }
312 }
313 
314 // Offset of sun_path in structure sockaddr_un.
sun_path_offset() -> usize315 fn sun_path_offset() -> usize {
316     // Prefer 0 to null() so that we do not need to subtract from the `sub_path` pointer.
317     #[allow(clippy::zero_ptr)]
318     let addr = 0 as *const libc::sockaddr_un;
319     // Safe because we only use the dereference to create a pointer to the desired field in
320     // calculating the offset.
321     unsafe { &(*addr).sun_path as *const _ as usize }
322 }
323 
324 // Return `sockaddr_un` for a given `path`
sockaddr_un<P: AsRef<Path>>(path: P) -> io::Result<(libc::sockaddr_un, libc::socklen_t)>325 fn sockaddr_un<P: AsRef<Path>>(path: P) -> io::Result<(libc::sockaddr_un, libc::socklen_t)> {
326     let mut addr = libc::sockaddr_un {
327         sun_family: libc::AF_UNIX as libc::sa_family_t,
328         sun_path: [0; 108],
329     };
330 
331     // Check if the input path is valid. Since
332     // * The pathname in sun_path should be null-terminated.
333     // * The length of the pathname, including the terminating null byte,
334     //   should not exceed the size of sun_path.
335     //
336     // and our input is a `Path`, we only need to check
337     // * If the string size of `Path` should less than sizeof(sun_path)
338     // and make sure `sun_path` ends with '\0' by initialized the sun_path with zeros.
339     //
340     // Empty path name is valid since abstract socket address has sun_paht[0] = '\0'
341     let bytes = path.as_ref().as_os_str().as_bytes();
342     if bytes.len() >= addr.sun_path.len() {
343         return Err(io::Error::new(
344             io::ErrorKind::InvalidInput,
345             "Input path size should be less than the length of sun_path.",
346         ));
347     };
348 
349     // Copy data from `path` to `addr.sun_path`
350     for (dst, src) in addr.sun_path.iter_mut().zip(bytes) {
351         *dst = *src as libc::c_char;
352     }
353 
354     // The addrlen argument that describes the enclosing sockaddr_un structure
355     // should have a value of at least:
356     //
357     //     offsetof(struct sockaddr_un, sun_path) + strlen(addr.sun_path) + 1
358     //
359     // or, more simply, addrlen can be specified as sizeof(struct sockaddr_un).
360     let len = sun_path_offset() + bytes.len() + 1;
361     Ok((addr, len as libc::socklen_t))
362 }
363 
364 /// A Unix `SOCK_SEQPACKET` socket point to given `path`
365 #[derive(Debug, Serialize, Deserialize)]
366 pub struct UnixSeqpacket(SafeDescriptor);
367 
368 // Safety: UnixSeqpacket is just a file descriptor, and it is safe to send FDs
369 // between threads.
370 unsafe impl Send for UnixSeqpacket {}
371 
372 impl UnixSeqpacket {
373     /// Open a `SOCK_SEQPACKET` connection to socket named by `path`.
374     ///
375     /// # Arguments
376     /// * `path` - Path to `SOCK_SEQPACKET` socket
377     ///
378     /// # Returns
379     /// A `UnixSeqpacket` structure point to the socket
380     ///
381     /// # Errors
382     /// Return `io::Error` when error occurs.
connect<P: AsRef<Path>>(path: P) -> io::Result<Self>383     pub fn connect<P: AsRef<Path>>(path: P) -> io::Result<Self> {
384         let descriptor = socket(libc::AF_UNIX, libc::SOCK_SEQPACKET, 0)?;
385         let (addr, len) = sockaddr_un(path.as_ref())?;
386         // Safe connect since we handle the error and use the right length generated from
387         // `sockaddr_un`.
388         unsafe {
389             let ret = libc::connect(
390                 descriptor.as_raw_descriptor(),
391                 &addr as *const _ as *const _,
392                 len,
393             );
394             if ret < 0 {
395                 return Err(io::Error::last_os_error());
396             }
397         }
398         Ok(UnixSeqpacket(descriptor))
399     }
400 
401     /// Creates a pair of connected `SOCK_SEQPACKET` sockets.
402     ///
403     /// Both returned file descriptors have the `CLOEXEC` flag set.
pair() -> io::Result<(UnixSeqpacket, UnixSeqpacket)>404     pub fn pair() -> io::Result<(UnixSeqpacket, UnixSeqpacket)> {
405         socketpair(libc::AF_UNIX, libc::SOCK_SEQPACKET | libc::SOCK_CLOEXEC, 0)
406             .map(|(s0, s1)| (UnixSeqpacket::from(s0), UnixSeqpacket::from(s1)))
407     }
408 
409     /// Clone the underlying FD.
try_clone(&self) -> io::Result<Self>410     pub fn try_clone(&self) -> io::Result<Self> {
411         Ok(Self(self.0.try_clone()?))
412     }
413 
414     /// Gets the number of bytes that can be read from this socket without blocking.
get_readable_bytes(&self) -> io::Result<usize>415     pub fn get_readable_bytes(&self) -> io::Result<usize> {
416         let mut byte_count = 0i32;
417         let ret = unsafe { libc::ioctl(self.as_raw_descriptor(), libc::FIONREAD, &mut byte_count) };
418         if ret < 0 {
419             Err(io::Error::last_os_error())
420         } else {
421             Ok(byte_count as usize)
422         }
423     }
424 
425     /// Gets the number of bytes in the next packet. This blocks as if `recv` were called,
426     /// respecting the blocking and timeout settings of the underlying socket.
next_packet_size(&self) -> io::Result<usize>427     pub fn next_packet_size(&self) -> io::Result<usize> {
428         #[cfg(not(debug_assertions))]
429         let buf = null_mut();
430         // Work around for qemu's syscall translation which will reject null pointers in recvfrom.
431         // This only matters for running the unit tests for a non-native architecture. See the
432         // upstream thread for the qemu fix:
433         // https://lists.nongnu.org/archive/html/qemu-devel/2021-03/msg09027.html
434         #[cfg(debug_assertions)]
435         let buf = &mut 0 as *mut _ as *mut _;
436 
437         // This form of recvfrom doesn't modify any data because all null pointers are used. We only
438         // use the return value and check for errors on an FD owned by this structure.
439         let ret = unsafe {
440             recvfrom(
441                 self.as_raw_descriptor(),
442                 buf,
443                 0,
444                 MSG_TRUNC | MSG_PEEK,
445                 null_mut(),
446                 null_mut(),
447             )
448         };
449         if ret < 0 {
450             Err(io::Error::last_os_error())
451         } else {
452             Ok(ret as usize)
453         }
454     }
455 
456     /// Write data from a given buffer to the socket fd
457     ///
458     /// # Arguments
459     /// * `buf` - A reference to the data buffer.
460     ///
461     /// # Returns
462     /// * `usize` - The size of bytes written to the buffer.
463     ///
464     /// # Errors
465     /// Returns error when `libc::write` failed.
send(&self, buf: &[u8]) -> io::Result<usize>466     pub fn send(&self, buf: &[u8]) -> io::Result<usize> {
467         // Safe since we make sure the input `count` == `buf.len()` and handle the returned error.
468         unsafe {
469             let ret = libc::write(
470                 self.as_raw_descriptor(),
471                 buf.as_ptr() as *const _,
472                 buf.len(),
473             );
474             if ret < 0 {
475                 Err(io::Error::last_os_error())
476             } else {
477                 Ok(ret as usize)
478             }
479         }
480     }
481 
482     /// Read data from the socket fd to a given buffer
483     ///
484     /// # Arguments
485     /// * `buf` - A mut reference to the data buffer.
486     ///
487     /// # Returns
488     /// * `usize` - The size of bytes read to the buffer.
489     ///
490     /// # Errors
491     /// Returns error when `libc::read` failed.
recv(&self, buf: &mut [u8]) -> io::Result<usize>492     pub fn recv(&self, buf: &mut [u8]) -> io::Result<usize> {
493         // Safe since we make sure the input `count` == `buf.len()` and handle the returned error.
494         unsafe {
495             let ret = libc::read(
496                 self.as_raw_descriptor(),
497                 buf.as_mut_ptr() as *mut _,
498                 buf.len(),
499             );
500             if ret < 0 {
501                 Err(io::Error::last_os_error())
502             } else {
503                 Ok(ret as usize)
504             }
505         }
506     }
507 
508     /// Read data from the socket fd to a given `Vec`, resizing it to the received packet's size.
509     ///
510     /// # Arguments
511     /// * `buf` - A mut reference to a `Vec` to resize and read into.
512     ///
513     /// # Errors
514     /// Returns error when `libc::read` or `get_readable_bytes` failed.
recv_to_vec(&self, buf: &mut Vec<u8>) -> io::Result<()>515     pub fn recv_to_vec(&self, buf: &mut Vec<u8>) -> io::Result<()> {
516         let packet_size = self.next_packet_size()?;
517         buf.resize(packet_size, 0);
518         let read_bytes = self.recv(buf)?;
519         buf.resize(read_bytes, 0);
520         Ok(())
521     }
522 
523     /// Read data from the socket fd to a new `Vec`.
524     ///
525     /// # Returns
526     /// * `vec` - A new `Vec` with the entire received packet.
527     ///
528     /// # Errors
529     /// Returns error when `libc::read` or `get_readable_bytes` failed.
recv_as_vec(&self) -> io::Result<Vec<u8>>530     pub fn recv_as_vec(&self) -> io::Result<Vec<u8>> {
531         let mut buf = Vec::new();
532         self.recv_to_vec(&mut buf)?;
533         Ok(buf)
534     }
535 
536     /// Read data and fds from the socket fd to a new pair of `Vec`.
537     ///
538     /// # Returns
539     /// * `Vec<u8>` - A new `Vec` with the entire received packet's bytes.
540     /// * `Vec<RawFd>` - A new `Vec` with the entire received packet's fds.
541     ///
542     /// # Errors
543     /// Returns error when `recv_with_fds` or `get_readable_bytes` failed.
recv_as_vec_with_fds(&self) -> io::Result<(Vec<u8>, Vec<RawFd>)>544     pub fn recv_as_vec_with_fds(&self) -> io::Result<(Vec<u8>, Vec<RawFd>)> {
545         let packet_size = self.next_packet_size()?;
546         let mut buf = vec![0; packet_size];
547         let mut fd_buf = vec![-1; SCM_SOCKET_MAX_FD_COUNT];
548         let (read_bytes, read_fds) =
549             self.recv_with_fds(io::IoSliceMut::new(&mut buf), &mut fd_buf)?;
550         buf.resize(read_bytes, 0);
551         fd_buf.resize(read_fds, -1);
552         Ok((buf, fd_buf))
553     }
554 
555     #[allow(clippy::useless_conversion)]
set_timeout(&self, timeout: Option<Duration>, kind: libc::c_int) -> io::Result<()>556     fn set_timeout(&self, timeout: Option<Duration>, kind: libc::c_int) -> io::Result<()> {
557         let timeval = match timeout {
558             Some(t) => {
559                 if t.as_secs() == 0 && t.subsec_micros() == 0 {
560                     return Err(io::Error::new(
561                         io::ErrorKind::InvalidInput,
562                         "zero timeout duration is invalid",
563                     ));
564                 }
565                 // subsec_micros fits in i32 because it is defined to be less than one million.
566                 let nsec = t.subsec_micros() as i32;
567                 libc::timeval {
568                     tv_sec: t.as_secs() as libc::time_t,
569                     tv_usec: libc::suseconds_t::from(nsec),
570                 }
571             }
572             None => libc::timeval {
573                 tv_sec: 0,
574                 tv_usec: 0,
575             },
576         };
577         // Safe because we own the fd, and the length of the pointer's data is the same as the
578         // passed in length parameter. The level argument is valid, the kind is assumed to be valid,
579         // and the return value is checked.
580         let ret = unsafe {
581             libc::setsockopt(
582                 self.as_raw_descriptor(),
583                 libc::SOL_SOCKET,
584                 kind,
585                 &timeval as *const libc::timeval as *const libc::c_void,
586                 mem::size_of::<libc::timeval>() as libc::socklen_t,
587             )
588         };
589         if ret < 0 {
590             Err(io::Error::last_os_error())
591         } else {
592             Ok(())
593         }
594     }
595 
596     /// Sets or removes the timeout for read/recv operations on this socket.
set_read_timeout(&self, timeout: Option<Duration>) -> io::Result<()>597     pub fn set_read_timeout(&self, timeout: Option<Duration>) -> io::Result<()> {
598         self.set_timeout(timeout, libc::SO_RCVTIMEO)
599     }
600 
601     /// Sets or removes the timeout for write/send operations on this socket.
set_write_timeout(&self, timeout: Option<Duration>) -> io::Result<()>602     pub fn set_write_timeout(&self, timeout: Option<Duration>) -> io::Result<()> {
603         self.set_timeout(timeout, libc::SO_SNDTIMEO)
604     }
605 
606     /// Sets the blocking mode for this socket.
set_nonblocking(&self, nonblocking: bool) -> io::Result<()>607     pub fn set_nonblocking(&self, nonblocking: bool) -> io::Result<()> {
608         let mut nonblocking = nonblocking as libc::c_int;
609         // Safe because the return value is checked, and this ioctl call sets the nonblocking mode
610         // and does not continue holding the file descriptor after the call.
611         let ret = unsafe { libc::ioctl(self.as_raw_descriptor(), libc::FIONBIO, &mut nonblocking) };
612         if ret < 0 {
613             Err(io::Error::last_os_error())
614         } else {
615             Ok(())
616         }
617     }
618 }
619 
620 impl From<UnixSeqpacket> for SafeDescriptor {
from(s: UnixSeqpacket) -> Self621     fn from(s: UnixSeqpacket) -> Self {
622         s.0
623     }
624 }
625 
626 impl From<SafeDescriptor> for UnixSeqpacket {
from(s: SafeDescriptor) -> Self627     fn from(s: SafeDescriptor) -> Self {
628         Self(s)
629     }
630 }
631 
632 impl FromRawDescriptor for UnixSeqpacket {
from_raw_descriptor(descriptor: RawDescriptor) -> Self633     unsafe fn from_raw_descriptor(descriptor: RawDescriptor) -> Self {
634         Self(SafeDescriptor::from_raw_descriptor(descriptor))
635     }
636 }
637 
638 impl AsRawDescriptor for UnixSeqpacket {
as_raw_descriptor(&self) -> RawDescriptor639     fn as_raw_descriptor(&self) -> RawDescriptor {
640         self.0.as_raw_descriptor()
641     }
642 }
643 
644 impl io::Read for UnixSeqpacket {
read(&mut self, buf: &mut [u8]) -> io::Result<usize>645     fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
646         self.recv(buf)
647     }
648 }
649 
650 impl io::Write for UnixSeqpacket {
write(&mut self, buf: &[u8]) -> io::Result<usize>651     fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
652         self.send(buf)
653     }
654 
flush(&mut self) -> io::Result<()>655     fn flush(&mut self) -> io::Result<()> {
656         Ok(())
657     }
658 }
659 
660 /// Like a `UnixListener` but for accepting `UnixSeqpacket` type sockets.
661 pub struct UnixSeqpacketListener {
662     descriptor: SafeDescriptor,
663     no_path: bool,
664 }
665 
666 impl UnixSeqpacketListener {
667     /// Creates a new `UnixSeqpacketListener` bound to the given path.
bind<P: AsRef<Path>>(path: P) -> io::Result<Self>668     pub fn bind<P: AsRef<Path>>(path: P) -> io::Result<Self> {
669         if path.as_ref().starts_with("/proc/self/fd/") {
670             let fd = path
671                 .as_ref()
672                 .file_name()
673                 .expect("Failed to get fd filename")
674                 .to_str()
675                 .expect("fd filename should be unicode")
676                 .parse::<i32>()
677                 .expect("fd should be an integer");
678             let mut result: c_int = 0;
679             let mut result_len = size_of::<c_int>() as libc::socklen_t;
680             let ret = unsafe {
681                 libc::getsockopt(
682                     fd,
683                     libc::SOL_SOCKET,
684                     libc::SO_ACCEPTCONN,
685                     &mut result as *mut _ as *mut libc::c_void,
686                     &mut result_len,
687                 )
688             };
689             if ret < 0 {
690                 return Err(io::Error::last_os_error());
691             }
692             if result != 1 {
693                 return Err(io::Error::new(
694                     io::ErrorKind::InvalidInput,
695                     "specified descriptor is not a listening socket",
696                 ));
697             }
698             // Safe because we validated the socket file descriptor.
699             let descriptor = unsafe { SafeDescriptor::from_raw_descriptor(fd) };
700             return Ok(UnixSeqpacketListener {
701                 descriptor,
702                 no_path: true,
703             });
704         }
705 
706         let descriptor = socket(libc::AF_UNIX, libc::SOCK_SEQPACKET, 0)?;
707         let (addr, len) = sockaddr_un(path.as_ref())?;
708 
709         // Safe connect since we handle the error and use the right length generated from
710         // `sockaddr_un`.
711         unsafe {
712             let ret = handle_eintr_errno!(libc::bind(
713                 descriptor.as_raw_descriptor(),
714                 &addr as *const _ as *const _,
715                 len
716             ));
717             if ret < 0 {
718                 return Err(io::Error::last_os_error());
719             }
720             let ret = handle_eintr_errno!(libc::listen(descriptor.as_raw_descriptor(), 128));
721             if ret < 0 {
722                 return Err(io::Error::last_os_error());
723             }
724         }
725         Ok(UnixSeqpacketListener {
726             descriptor,
727             no_path: false,
728         })
729     }
730 
731     /// Blocks for and accepts a new incoming connection and returns the socket associated with that
732     /// connection.
733     ///
734     /// The returned socket has the close-on-exec flag set.
accept(&self) -> io::Result<UnixSeqpacket>735     pub fn accept(&self) -> io::Result<UnixSeqpacket> {
736         // Safe because we own this fd and the kernel will not write to null pointers.
737         match unsafe {
738             libc::accept4(
739                 self.as_raw_descriptor(),
740                 null_mut(),
741                 null_mut(),
742                 libc::SOCK_CLOEXEC,
743             )
744         } {
745             -1 => Err(io::Error::last_os_error()),
746             fd => {
747                 // Safe because we checked the return value of accept. Therefore, the return value
748                 // must be a valid socket.
749                 Ok(UnixSeqpacket::from(unsafe {
750                     SafeDescriptor::from_raw_descriptor(fd)
751                 }))
752             }
753         }
754     }
755 
accept_with_timeout(&self, timeout: Duration) -> io::Result<UnixSeqpacket>756     pub fn accept_with_timeout(&self, timeout: Duration) -> io::Result<UnixSeqpacket> {
757         let start = Instant::now();
758 
759         loop {
760             let mut fds = libc::pollfd {
761                 fd: self.as_raw_descriptor(),
762                 events: libc::POLLIN,
763                 revents: 0,
764             };
765             let elapsed = Instant::now().saturating_duration_since(start);
766             let remaining = timeout.checked_sub(elapsed).unwrap_or(Duration::ZERO);
767             let cur_timeout_ms = i32::try_from(remaining.as_millis()).unwrap_or(i32::MAX);
768             // Safe because we give a valid pointer to a list (of 1) FD and we check
769             // the return value.
770             match unsafe { libc::poll(&mut fds, 1, cur_timeout_ms) }.cmp(&0) {
771                 Ordering::Greater => return self.accept(),
772                 Ordering::Equal => return Err(io::Error::from_raw_os_error(libc::ETIMEDOUT)),
773                 Ordering::Less => {
774                     if Error::last() != Error::new(libc::EINTR) {
775                         return Err(io::Error::last_os_error());
776                     }
777                 }
778             }
779         }
780     }
781 
782     /// Gets the path that this listener is bound to.
path(&self) -> io::Result<PathBuf>783     pub fn path(&self) -> io::Result<PathBuf> {
784         let mut addr = libc::sockaddr_un {
785             sun_family: libc::AF_UNIX as libc::sa_family_t,
786             sun_path: [0; 108],
787         };
788         if self.no_path {
789             return Err(io::Error::new(
790                 io::ErrorKind::InvalidInput,
791                 "socket has no path",
792             ));
793         }
794         let sun_path_offset = (&addr.sun_path as *const _ as usize
795             - &addr.sun_family as *const _ as usize)
796             as libc::socklen_t;
797         let mut len = mem::size_of::<libc::sockaddr_un>() as libc::socklen_t;
798         // Safe because the length given matches the length of the data of the given pointer, and we
799         // check the return value.
800         let ret = unsafe {
801             handle_eintr_errno!(libc::getsockname(
802                 self.as_raw_descriptor(),
803                 &mut addr as *mut libc::sockaddr_un as *mut libc::sockaddr,
804                 &mut len
805             ))
806         };
807         if ret < 0 {
808             return Err(io::Error::last_os_error());
809         }
810         if addr.sun_family != libc::AF_UNIX as libc::sa_family_t
811             || addr.sun_path[0] == 0
812             || len < 1 + sun_path_offset
813         {
814             return Err(io::Error::new(
815                 io::ErrorKind::InvalidInput,
816                 "getsockname on socket returned invalid value",
817             ));
818         }
819 
820         let path_os_str = OsString::from_vec(
821             addr.sun_path[..(len - sun_path_offset - 1) as usize]
822                 .iter()
823                 .map(|&c| c as _)
824                 .collect(),
825         );
826         Ok(path_os_str.into())
827     }
828 
829     /// Sets the blocking mode for this socket.
set_nonblocking(&self, nonblocking: bool) -> io::Result<()>830     pub fn set_nonblocking(&self, nonblocking: bool) -> io::Result<()> {
831         let mut nonblocking = nonblocking as libc::c_int;
832         // Safe because the return value is checked, and this ioctl call sets the nonblocking mode
833         // and does not continue holding the file descriptor after the call.
834         let ret = unsafe { libc::ioctl(self.as_raw_descriptor(), libc::FIONBIO, &mut nonblocking) };
835         if ret < 0 {
836             Err(io::Error::last_os_error())
837         } else {
838             Ok(())
839         }
840     }
841 }
842 
843 impl AsRawDescriptor for UnixSeqpacketListener {
as_raw_descriptor(&self) -> RawDescriptor844     fn as_raw_descriptor(&self) -> RawDescriptor {
845         self.descriptor.as_raw_descriptor()
846     }
847 }
848 
849 /// Used to attempt to clean up a `UnixSeqpacketListener` after it is dropped.
850 pub struct UnlinkUnixSeqpacketListener(pub UnixSeqpacketListener);
851 
852 impl AsRawDescriptor for UnlinkUnixSeqpacketListener {
as_raw_descriptor(&self) -> RawDescriptor853     fn as_raw_descriptor(&self) -> RawDescriptor {
854         self.0.as_raw_descriptor()
855     }
856 }
857 
858 impl AsRef<UnixSeqpacketListener> for UnlinkUnixSeqpacketListener {
as_ref(&self) -> &UnixSeqpacketListener859     fn as_ref(&self) -> &UnixSeqpacketListener {
860         &self.0
861     }
862 }
863 
864 impl Deref for UnlinkUnixSeqpacketListener {
865     type Target = UnixSeqpacketListener;
deref(&self) -> &Self::Target866     fn deref(&self) -> &Self::Target {
867         &self.0
868     }
869 }
870 
871 impl Drop for UnlinkUnixSeqpacketListener {
drop(&mut self)872     fn drop(&mut self) {
873         if let Ok(path) = self.0.path() {
874             if let Err(e) = remove_file(path) {
875                 warn!("failed to remove control socket file: {:?}", e);
876             }
877         }
878     }
879 }
880 
#[cfg(test)]
mod tests {
    use super::*;

    /// Produces a string made of `count` copies of the letter `a`, used as a socket path.
    fn repeated_a(count: usize) -> String {
        "a".repeat(count)
    }

    /// An empty path is accepted by `sockaddr_un`.
    #[test]
    fn sockaddr_un_zero_length_input() {
        sockaddr_un(Path::new("")).expect("sockaddr_un failed");
    }

    /// A 108-byte path is rejected.
    #[test]
    fn sockaddr_un_long_input_err() {
        let long_path = repeated_a(108);
        assert!(sockaddr_un(Path::new(&long_path)).is_err());
    }

    /// A 107-byte path is accepted.
    #[test]
    fn sockaddr_un_long_input_pass() {
        let long_path = repeated_a(107);
        sockaddr_un(Path::new(&long_path)).expect("sockaddr_un failed");
    }

    /// The reported length is the `sun_path` offset plus the path length plus one.
    #[test]
    fn sockaddr_un_len_check() {
        let path = repeated_a(50);
        let (_addr, len) = sockaddr_un(Path::new(&path)).expect("sockaddr_un failed");
        assert_eq!(len, (sun_path_offset() + 50 + 1) as u32);
    }

    /// The returned address carries the right family, length and zero-padded path bytes.
    #[test]
    #[allow(clippy::unnecessary_cast)]
    // c_char is u8 on aarch64 and i8 on x86, so clippy's suggested fix of changing
    // `'a' as libc::c_char` below to `b'a'` won't work everywhere.
    #[allow(clippy::char_lit_as_u8)]
    fn sockaddr_un_pass() {
        let path_size = 50;
        let path = repeated_a(path_size);
        let (addr, len) = sockaddr_un(Path::new(&path)).expect("sockaddr_un failed");
        assert_eq!(len, (sun_path_offset() + path_size + 1) as u32);
        assert_eq!(addr.sun_family, libc::AF_UNIX as libc::sa_family_t);

        // Check `sun_path` in returned `sockaddr_un`: `path_size` copies of 'a', then zeros.
        let mut expected_sun_path = [0 as libc::c_char; 108];
        expected_sun_path[..path_size].fill('a' as libc::c_char);

        assert_eq!(addr.sun_path[..], expected_sun_path[..]);
    }
}
930