• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 // Copyright 2018 The Chromium OS Authors. All rights reserved.
2 // Use of this source code is governed by a BSD-style license that can be
3 // found in the LICENSE file.
4 
5 use std::{
6     cmp::Ordering,
7     convert::TryFrom,
8     ffi::OsString,
9     fs::remove_file,
10     io,
11     mem::{
12         size_of, {self},
13     },
14     net::{SocketAddr, SocketAddrV4, SocketAddrV6, TcpListener, TcpStream, ToSocketAddrs},
15     ops::Deref,
16     os::unix::{
17         ffi::{OsStrExt, OsStringExt},
18         io::{AsRawFd, FromRawFd, IntoRawFd, RawFd},
19     },
20     path::{Path, PathBuf},
21     ptr::null_mut,
22     time::{Duration, Instant},
23 };
24 
25 use libc::{
26     c_int, in6_addr, in_addr, recvfrom, sa_family_t, sockaddr, sockaddr_in, sockaddr_in6,
27     socklen_t, AF_INET, AF_INET6, MSG_PEEK, MSG_TRUNC, SOCK_CLOEXEC, SOCK_STREAM,
28 };
29 use serde::{Deserialize, Serialize};
30 
31 use super::{
32     sock_ctrl_msg::{ScmSocket, SCM_SOCKET_MAX_FD_COUNT},
33     Error, RawDescriptor,
34 };
35 use crate::descriptor::{AsRawDescriptor, FromRawDescriptor, IntoRawDescriptor};
36 
/// Distinguishes between the two IP protocol families so callers can handle both IPv4 and IPv6.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum InetVersion {
    /// IP version 4 (maps to `AF_INET`).
    V4,
    /// IP version 6 (maps to `AF_INET6`).
    V6,
}
43 
44 impl InetVersion {
from_sockaddr(s: &SocketAddr) -> Self45     pub fn from_sockaddr(s: &SocketAddr) -> Self {
46         match s {
47             SocketAddr::V4(_) => InetVersion::V4,
48             SocketAddr::V6(_) => InetVersion::V6,
49         }
50     }
51 }
52 
53 impl From<InetVersion> for sa_family_t {
from(v: InetVersion) -> sa_family_t54     fn from(v: InetVersion) -> sa_family_t {
55         match v {
56             InetVersion::V4 => AF_INET as sa_family_t,
57             InetVersion::V6 => AF_INET6 as sa_family_t,
58         }
59     }
60 }
61 
sockaddrv4_to_lib_c(s: &SocketAddrV4) -> sockaddr_in62 fn sockaddrv4_to_lib_c(s: &SocketAddrV4) -> sockaddr_in {
63     sockaddr_in {
64         sin_family: AF_INET as sa_family_t,
65         sin_port: s.port().to_be(),
66         sin_addr: in_addr {
67             s_addr: u32::from_ne_bytes(s.ip().octets()),
68         },
69         sin_zero: [0; 8],
70     }
71 }
72 
sockaddrv6_to_lib_c(s: &SocketAddrV6) -> sockaddr_in673 fn sockaddrv6_to_lib_c(s: &SocketAddrV6) -> sockaddr_in6 {
74     sockaddr_in6 {
75         sin6_family: AF_INET6 as sa_family_t,
76         sin6_port: s.port().to_be(),
77         sin6_flowinfo: 0,
78         sin6_addr: in6_addr {
79             s6_addr: s.ip().octets(),
80         },
81         sin6_scope_id: 0,
82     }
83 }
84 
85 /// A TCP socket.
86 ///
87 /// Do not use this class unless you need to change socket options or query the
88 /// state of the socket prior to calling listen or connect. Instead use either TcpStream or
89 /// TcpListener.
90 #[derive(Debug)]
91 pub struct TcpSocket {
92     inet_version: InetVersion,
93     fd: RawFd,
94 }
95 
96 impl TcpSocket {
new(inet_version: InetVersion) -> io::Result<Self>97     pub fn new(inet_version: InetVersion) -> io::Result<Self> {
98         let fd = unsafe {
99             libc::socket(
100                 Into::<sa_family_t>::into(inet_version) as c_int,
101                 SOCK_STREAM | SOCK_CLOEXEC,
102                 0,
103             )
104         };
105         if fd < 0 {
106             Err(io::Error::last_os_error())
107         } else {
108             Ok(TcpSocket { inet_version, fd })
109         }
110     }
111 
bind<A: ToSocketAddrs>(&mut self, addr: A) -> io::Result<()>112     pub fn bind<A: ToSocketAddrs>(&mut self, addr: A) -> io::Result<()> {
113         let sockaddr = addr
114             .to_socket_addrs()
115             .map_err(|_| io::Error::from_raw_os_error(libc::EINVAL))?
116             .next()
117             .unwrap();
118 
119         let ret = match sockaddr {
120             SocketAddr::V4(a) => {
121                 let sin = sockaddrv4_to_lib_c(&a);
122                 // Safe because this doesn't modify any memory and we check the return value.
123                 unsafe {
124                     libc::bind(
125                         self.fd,
126                         &sin as *const sockaddr_in as *const sockaddr,
127                         size_of::<sockaddr_in>() as socklen_t,
128                     )
129                 }
130             }
131             SocketAddr::V6(a) => {
132                 let sin6 = sockaddrv6_to_lib_c(&a);
133                 // Safe because this doesn't modify any memory and we check the return value.
134                 unsafe {
135                     libc::bind(
136                         self.fd,
137                         &sin6 as *const sockaddr_in6 as *const sockaddr,
138                         size_of::<sockaddr_in6>() as socklen_t,
139                     )
140                 }
141             }
142         };
143         if ret < 0 {
144             let bind_err = io::Error::last_os_error();
145             Err(bind_err)
146         } else {
147             Ok(())
148         }
149     }
150 
connect<A: ToSocketAddrs>(self, addr: A) -> io::Result<TcpStream>151     pub fn connect<A: ToSocketAddrs>(self, addr: A) -> io::Result<TcpStream> {
152         let sockaddr = addr
153             .to_socket_addrs()
154             .map_err(|_| io::Error::from_raw_os_error(libc::EINVAL))?
155             .next()
156             .unwrap();
157 
158         let ret = match sockaddr {
159             SocketAddr::V4(a) => {
160                 let sin = sockaddrv4_to_lib_c(&a);
161                 // Safe because this doesn't modify any memory and we check the return value.
162                 unsafe {
163                     libc::connect(
164                         self.fd,
165                         &sin as *const sockaddr_in as *const sockaddr,
166                         size_of::<sockaddr_in>() as socklen_t,
167                     )
168                 }
169             }
170             SocketAddr::V6(a) => {
171                 let sin6 = sockaddrv6_to_lib_c(&a);
172                 // Safe because this doesn't modify any memory and we check the return value.
173                 unsafe {
174                     libc::connect(
175                         self.fd,
176                         &sin6 as *const sockaddr_in6 as *const sockaddr,
177                         size_of::<sockaddr_in>() as socklen_t,
178                     )
179                 }
180             }
181         };
182 
183         if ret < 0 {
184             let connect_err = io::Error::last_os_error();
185             Err(connect_err)
186         } else {
187             // Safe because the ownership of the raw fd is released from self and taken over by the
188             // new TcpStream.
189             Ok(unsafe { TcpStream::from_raw_fd(self.into_raw_fd()) })
190         }
191     }
192 
listen(self) -> io::Result<TcpListener>193     pub fn listen(self) -> io::Result<TcpListener> {
194         // Safe because this doesn't modify any memory and we check the return value.
195         let ret = unsafe { libc::listen(self.fd, 1) };
196         if ret < 0 {
197             let listen_err = io::Error::last_os_error();
198             Err(listen_err)
199         } else {
200             // Safe because the ownership of the raw fd is released from self and taken over by the
201             // new TcpListener.
202             Ok(unsafe { TcpListener::from_raw_fd(self.into_raw_fd()) })
203         }
204     }
205 
206     /// Returns the port that this socket is bound to. This can only succeed after bind is called.
local_port(&self) -> io::Result<u16>207     pub fn local_port(&self) -> io::Result<u16> {
208         match self.inet_version {
209             InetVersion::V4 => {
210                 let mut sin = sockaddr_in {
211                     sin_family: 0,
212                     sin_port: 0,
213                     sin_addr: in_addr { s_addr: 0 },
214                     sin_zero: [0; 8],
215                 };
216 
217                 // Safe because we give a valid pointer for addrlen and check the length.
218                 let mut addrlen = size_of::<sockaddr_in>() as socklen_t;
219                 let ret = unsafe {
220                     // Get the socket address that was actually bound.
221                     libc::getsockname(
222                         self.fd,
223                         &mut sin as *mut sockaddr_in as *mut sockaddr,
224                         &mut addrlen as *mut socklen_t,
225                     )
226                 };
227                 if ret < 0 {
228                     let getsockname_err = io::Error::last_os_error();
229                     Err(getsockname_err)
230                 } else {
231                     // If this doesn't match, it's not safe to get the port out of the sockaddr.
232                     assert_eq!(addrlen as usize, size_of::<sockaddr_in>());
233 
234                     Ok(u16::from_be(sin.sin_port))
235                 }
236             }
237             InetVersion::V6 => {
238                 let mut sin6 = sockaddr_in6 {
239                     sin6_family: 0,
240                     sin6_port: 0,
241                     sin6_flowinfo: 0,
242                     sin6_addr: in6_addr { s6_addr: [0; 16] },
243                     sin6_scope_id: 0,
244                 };
245 
246                 // Safe because we give a valid pointer for addrlen and check the length.
247                 let mut addrlen = size_of::<sockaddr_in6>() as socklen_t;
248                 let ret = unsafe {
249                     // Get the socket address that was actually bound.
250                     libc::getsockname(
251                         self.fd,
252                         &mut sin6 as *mut sockaddr_in6 as *mut sockaddr,
253                         &mut addrlen as *mut socklen_t,
254                     )
255                 };
256                 if ret < 0 {
257                     let getsockname_err = io::Error::last_os_error();
258                     Err(getsockname_err)
259                 } else {
260                     // If this doesn't match, it's not safe to get the port out of the sockaddr.
261                     assert_eq!(addrlen as usize, size_of::<sockaddr_in>());
262 
263                     Ok(u16::from_be(sin6.sin6_port))
264                 }
265             }
266         }
267     }
268 }
269 
270 impl IntoRawFd for TcpSocket {
into_raw_fd(self) -> RawFd271     fn into_raw_fd(self) -> RawFd {
272         let fd = self.fd;
273         mem::forget(self);
274         fd
275     }
276 }
277 
278 impl AsRawFd for TcpSocket {
as_raw_fd(&self) -> RawFd279     fn as_raw_fd(&self) -> RawFd {
280         self.fd
281     }
282 }
283 
284 impl Drop for TcpSocket {
drop(&mut self)285     fn drop(&mut self) {
286         // Safe because this doesn't modify any memory and we are the only
287         // owner of the file descriptor.
288         unsafe { libc::close(self.fd) };
289     }
290 }
291 
292 // Offset of sun_path in structure sockaddr_un.
sun_path_offset() -> usize293 fn sun_path_offset() -> usize {
294     // Prefer 0 to null() so that we do not need to subtract from the `sub_path` pointer.
295     #[allow(clippy::zero_ptr)]
296     let addr = 0 as *const libc::sockaddr_un;
297     // Safe because we only use the dereference to create a pointer to the desired field in
298     // calculating the offset.
299     unsafe { &(*addr).sun_path as *const _ as usize }
300 }
301 
302 // Return `sockaddr_un` for a given `path`
sockaddr_un<P: AsRef<Path>>(path: P) -> io::Result<(libc::sockaddr_un, libc::socklen_t)>303 fn sockaddr_un<P: AsRef<Path>>(path: P) -> io::Result<(libc::sockaddr_un, libc::socklen_t)> {
304     let mut addr = libc::sockaddr_un {
305         sun_family: libc::AF_UNIX as libc::sa_family_t,
306         sun_path: [0; 108],
307     };
308 
309     // Check if the input path is valid. Since
310     // * The pathname in sun_path should be null-terminated.
311     // * The length of the pathname, including the terminating null byte,
312     //   should not exceed the size of sun_path.
313     //
314     // and our input is a `Path`, we only need to check
315     // * If the string size of `Path` should less than sizeof(sun_path)
316     // and make sure `sun_path` ends with '\0' by initialized the sun_path with zeros.
317     //
318     // Empty path name is valid since abstract socket address has sun_paht[0] = '\0'
319     let bytes = path.as_ref().as_os_str().as_bytes();
320     if bytes.len() >= addr.sun_path.len() {
321         return Err(io::Error::new(
322             io::ErrorKind::InvalidInput,
323             "Input path size should be less than the length of sun_path.",
324         ));
325     };
326 
327     // Copy data from `path` to `addr.sun_path`
328     for (dst, src) in addr.sun_path.iter_mut().zip(bytes) {
329         *dst = *src as libc::c_char;
330     }
331 
332     // The addrlen argument that describes the enclosing sockaddr_un structure
333     // should have a value of at least:
334     //
335     //     offsetof(struct sockaddr_un, sun_path) + strlen(addr.sun_path) + 1
336     //
337     // or, more simply, addrlen can be specified as sizeof(struct sockaddr_un).
338     let len = sun_path_offset() + bytes.len() + 1;
339     Ok((addr, len as libc::socklen_t))
340 }
341 
342 /// A Unix `SOCK_SEQPACKET` socket point to given `path`
343 #[derive(Debug, Serialize, Deserialize)]
344 pub struct UnixSeqpacket {
345     #[serde(with = "super::with_raw_descriptor")]
346     fd: RawFd,
347 }
348 
349 impl UnixSeqpacket {
350     /// Open a `SOCK_SEQPACKET` connection to socket named by `path`.
351     ///
352     /// # Arguments
353     /// * `path` - Path to `SOCK_SEQPACKET` socket
354     ///
355     /// # Returns
356     /// A `UnixSeqpacket` structure point to the socket
357     ///
358     /// # Errors
359     /// Return `io::Error` when error occurs.
connect<P: AsRef<Path>>(path: P) -> io::Result<Self>360     pub fn connect<P: AsRef<Path>>(path: P) -> io::Result<Self> {
361         // Safe socket initialization since we handle the returned error.
362         let fd = unsafe {
363             match libc::socket(libc::AF_UNIX, libc::SOCK_SEQPACKET, 0) {
364                 -1 => return Err(io::Error::last_os_error()),
365                 fd => fd,
366             }
367         };
368 
369         let (addr, len) = sockaddr_un(path.as_ref())?;
370         // Safe connect since we handle the error and use the right length generated from
371         // `sockaddr_un`.
372         unsafe {
373             let ret = libc::connect(fd, &addr as *const _ as *const _, len);
374             if ret < 0 {
375                 return Err(io::Error::last_os_error());
376             }
377         }
378         Ok(UnixSeqpacket { fd })
379     }
380 
381     /// Creates a pair of connected `SOCK_SEQPACKET` sockets.
382     ///
383     /// Both returned file descriptors have the `CLOEXEC` flag set.s
pair() -> io::Result<(UnixSeqpacket, UnixSeqpacket)>384     pub fn pair() -> io::Result<(UnixSeqpacket, UnixSeqpacket)> {
385         let mut fds = [0, 0];
386         unsafe {
387             // Safe because we give enough space to store all the fds and we check the return value.
388             let ret = libc::socketpair(
389                 libc::AF_UNIX,
390                 libc::SOCK_SEQPACKET | libc::SOCK_CLOEXEC,
391                 0,
392                 &mut fds[0],
393             );
394             if ret == 0 {
395                 Ok((
396                     UnixSeqpacket::from_raw_fd(fds[0]),
397                     UnixSeqpacket::from_raw_fd(fds[1]),
398                 ))
399             } else {
400                 Err(io::Error::last_os_error())
401             }
402         }
403     }
404 
405     /// Clone the underlying FD.
try_clone(&self) -> io::Result<Self>406     pub fn try_clone(&self) -> io::Result<Self> {
407         // Safe because this doesn't modify any memory and we check the return value.
408         let fd = unsafe { libc::fcntl(self.fd, libc::F_DUPFD_CLOEXEC, 0) };
409         if fd < 0 {
410             Err(io::Error::last_os_error())
411         } else {
412             Ok(Self { fd })
413         }
414     }
415 
416     /// Gets the number of bytes that can be read from this socket without blocking.
get_readable_bytes(&self) -> io::Result<usize>417     pub fn get_readable_bytes(&self) -> io::Result<usize> {
418         let mut byte_count = 0i32;
419         let ret = unsafe { libc::ioctl(self.fd, libc::FIONREAD, &mut byte_count) };
420         if ret < 0 {
421             Err(io::Error::last_os_error())
422         } else {
423             Ok(byte_count as usize)
424         }
425     }
426 
427     /// Gets the number of bytes in the next packet. This blocks as if `recv` were called,
428     /// respecting the blocking and timeout settings of the underlying socket.
next_packet_size(&self) -> io::Result<usize>429     pub fn next_packet_size(&self) -> io::Result<usize> {
430         #[cfg(not(debug_assertions))]
431         let buf = null_mut();
432         // Work around for qemu's syscall translation which will reject null pointers in recvfrom.
433         // This only matters for running the unit tests for a non-native architecture. See the
434         // upstream thread for the qemu fix:
435         // https://lists.nongnu.org/archive/html/qemu-devel/2021-03/msg09027.html
436         #[cfg(debug_assertions)]
437         let buf = &mut 0 as *mut _ as *mut _;
438 
439         // This form of recvfrom doesn't modify any data because all null pointers are used. We only
440         // use the return value and check for errors on an FD owned by this structure.
441         let ret = unsafe {
442             recvfrom(
443                 self.fd,
444                 buf,
445                 0,
446                 MSG_TRUNC | MSG_PEEK,
447                 null_mut(),
448                 null_mut(),
449             )
450         };
451         if ret < 0 {
452             Err(io::Error::last_os_error())
453         } else {
454             Ok(ret as usize)
455         }
456     }
457 
458     /// Write data from a given buffer to the socket fd
459     ///
460     /// # Arguments
461     /// * `buf` - A reference to the data buffer.
462     ///
463     /// # Returns
464     /// * `usize` - The size of bytes written to the buffer.
465     ///
466     /// # Errors
467     /// Returns error when `libc::write` failed.
send(&self, buf: &[u8]) -> io::Result<usize>468     pub fn send(&self, buf: &[u8]) -> io::Result<usize> {
469         // Safe since we make sure the input `count` == `buf.len()` and handle the returned error.
470         unsafe {
471             let ret = libc::write(self.fd, buf.as_ptr() as *const _, buf.len());
472             if ret < 0 {
473                 Err(io::Error::last_os_error())
474             } else {
475                 Ok(ret as usize)
476             }
477         }
478     }
479 
480     /// Read data from the socket fd to a given buffer
481     ///
482     /// # Arguments
483     /// * `buf` - A mut reference to the data buffer.
484     ///
485     /// # Returns
486     /// * `usize` - The size of bytes read to the buffer.
487     ///
488     /// # Errors
489     /// Returns error when `libc::read` failed.
recv(&self, buf: &mut [u8]) -> io::Result<usize>490     pub fn recv(&self, buf: &mut [u8]) -> io::Result<usize> {
491         // Safe since we make sure the input `count` == `buf.len()` and handle the returned error.
492         unsafe {
493             let ret = libc::read(self.fd, buf.as_mut_ptr() as *mut _, buf.len());
494             if ret < 0 {
495                 Err(io::Error::last_os_error())
496             } else {
497                 Ok(ret as usize)
498             }
499         }
500     }
501 
502     /// Read data from the socket fd to a given `Vec`, resizing it to the received packet's size.
503     ///
504     /// # Arguments
505     /// * `buf` - A mut reference to a `Vec` to resize and read into.
506     ///
507     /// # Errors
508     /// Returns error when `libc::read` or `get_readable_bytes` failed.
recv_to_vec(&self, buf: &mut Vec<u8>) -> io::Result<()>509     pub fn recv_to_vec(&self, buf: &mut Vec<u8>) -> io::Result<()> {
510         let packet_size = self.next_packet_size()?;
511         buf.resize(packet_size, 0);
512         let read_bytes = self.recv(buf)?;
513         buf.resize(read_bytes, 0);
514         Ok(())
515     }
516 
517     /// Read data from the socket fd to a new `Vec`.
518     ///
519     /// # Returns
520     /// * `vec` - A new `Vec` with the entire received packet.
521     ///
522     /// # Errors
523     /// Returns error when `libc::read` or `get_readable_bytes` failed.
recv_as_vec(&self) -> io::Result<Vec<u8>>524     pub fn recv_as_vec(&self) -> io::Result<Vec<u8>> {
525         let mut buf = Vec::new();
526         self.recv_to_vec(&mut buf)?;
527         Ok(buf)
528     }
529 
530     /// Read data and fds from the socket fd to a new pair of `Vec`.
531     ///
532     /// # Returns
533     /// * `Vec<u8>` - A new `Vec` with the entire received packet's bytes.
534     /// * `Vec<RawFd>` - A new `Vec` with the entire received packet's fds.
535     ///
536     /// # Errors
537     /// Returns error when `recv_with_fds` or `get_readable_bytes` failed.
recv_as_vec_with_fds(&self) -> io::Result<(Vec<u8>, Vec<RawFd>)>538     pub fn recv_as_vec_with_fds(&self) -> io::Result<(Vec<u8>, Vec<RawFd>)> {
539         let packet_size = self.next_packet_size()?;
540         let mut buf = vec![0; packet_size];
541         let mut fd_buf = vec![-1; SCM_SOCKET_MAX_FD_COUNT];
542         let (read_bytes, read_fds) =
543             self.recv_with_fds(io::IoSliceMut::new(&mut buf), &mut fd_buf)?;
544         buf.resize(read_bytes, 0);
545         fd_buf.resize(read_fds, -1);
546         Ok((buf, fd_buf))
547     }
548 
set_timeout(&self, timeout: Option<Duration>, kind: libc::c_int) -> io::Result<()>549     fn set_timeout(&self, timeout: Option<Duration>, kind: libc::c_int) -> io::Result<()> {
550         let timeval = match timeout {
551             Some(t) => {
552                 if t.as_secs() == 0 && t.subsec_micros() == 0 {
553                     return Err(io::Error::new(
554                         io::ErrorKind::InvalidInput,
555                         "zero timeout duration is invalid",
556                     ));
557                 }
558                 // subsec_micros fits in i32 because it is defined to be less than one million.
559                 let nsec = t.subsec_micros() as i32;
560                 libc::timeval {
561                     tv_sec: t.as_secs() as libc::time_t,
562                     tv_usec: libc::suseconds_t::from(nsec),
563                 }
564             }
565             None => libc::timeval {
566                 tv_sec: 0,
567                 tv_usec: 0,
568             },
569         };
570         // Safe because we own the fd, and the length of the pointer's data is the same as the
571         // passed in length parameter. The level argument is valid, the kind is assumed to be valid,
572         // and the return value is checked.
573         let ret = unsafe {
574             libc::setsockopt(
575                 self.fd,
576                 libc::SOL_SOCKET,
577                 kind,
578                 &timeval as *const libc::timeval as *const libc::c_void,
579                 mem::size_of::<libc::timeval>() as libc::socklen_t,
580             )
581         };
582         if ret < 0 {
583             Err(io::Error::last_os_error())
584         } else {
585             Ok(())
586         }
587     }
588 
589     /// Sets or removes the timeout for read/recv operations on this socket.
set_read_timeout(&self, timeout: Option<Duration>) -> io::Result<()>590     pub fn set_read_timeout(&self, timeout: Option<Duration>) -> io::Result<()> {
591         self.set_timeout(timeout, libc::SO_RCVTIMEO)
592     }
593 
594     /// Sets or removes the timeout for write/send operations on this socket.
set_write_timeout(&self, timeout: Option<Duration>) -> io::Result<()>595     pub fn set_write_timeout(&self, timeout: Option<Duration>) -> io::Result<()> {
596         self.set_timeout(timeout, libc::SO_SNDTIMEO)
597     }
598 }
599 
600 impl Drop for UnixSeqpacket {
drop(&mut self)601     fn drop(&mut self) {
602         // Safe if the UnixSeqpacket is created from Self::connect.
603         unsafe {
604             libc::close(self.fd);
605         }
606     }
607 }
608 
609 impl FromRawFd for UnixSeqpacket {
610     // Unsafe in drop function
from_raw_fd(fd: RawFd) -> Self611     unsafe fn from_raw_fd(fd: RawFd) -> Self {
612         Self { fd }
613     }
614 }
615 
616 impl FromRawDescriptor for UnixSeqpacket {
from_raw_descriptor(descriptor: RawDescriptor) -> Self617     unsafe fn from_raw_descriptor(descriptor: RawDescriptor) -> Self {
618         Self { fd: descriptor }
619     }
620 }
621 
622 impl IntoRawDescriptor for UnixSeqpacket {
into_raw_descriptor(self) -> RawDescriptor623     fn into_raw_descriptor(self) -> RawDescriptor {
624         let fd = self.fd;
625         mem::forget(self);
626         fd
627     }
628 }
629 
630 impl AsRawFd for UnixSeqpacket {
as_raw_fd(&self) -> RawFd631     fn as_raw_fd(&self) -> RawFd {
632         self.fd
633     }
634 }
635 
636 impl AsRawFd for &UnixSeqpacket {
as_raw_fd(&self) -> RawFd637     fn as_raw_fd(&self) -> RawFd {
638         self.fd
639     }
640 }
641 
642 impl AsRawDescriptor for UnixSeqpacket {
as_raw_descriptor(&self) -> RawDescriptor643     fn as_raw_descriptor(&self) -> RawDescriptor {
644         self.fd
645     }
646 }
647 
648 impl io::Read for UnixSeqpacket {
read(&mut self, buf: &mut [u8]) -> io::Result<usize>649     fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
650         self.recv(buf)
651     }
652 }
653 
654 impl io::Write for UnixSeqpacket {
write(&mut self, buf: &[u8]) -> io::Result<usize>655     fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
656         self.send(buf)
657     }
658 
flush(&mut self) -> io::Result<()>659     fn flush(&mut self) -> io::Result<()> {
660         Ok(())
661     }
662 }
663 
664 impl IntoRawFd for UnixSeqpacket {
into_raw_fd(self) -> RawFd665     fn into_raw_fd(self) -> RawFd {
666         let fd = self.fd;
667         mem::forget(self);
668         fd
669     }
670 }
671 
/// The `SOCK_SEQPACKET` counterpart of `UnixListener`: accepts connections as `UnixSeqpacket`s.
pub struct UnixSeqpacketListener {
    /// Owned listening socket descriptor; closed on drop.
    fd: RawFd,
}
676 
677 impl UnixSeqpacketListener {
678     /// Creates a new `UnixSeqpacketListener` bound to the given path.
bind<P: AsRef<Path>>(path: P) -> io::Result<Self>679     pub fn bind<P: AsRef<Path>>(path: P) -> io::Result<Self> {
680         // Safe socket initialization since we handle the returned error.
681         let fd = unsafe {
682             match libc::socket(libc::AF_UNIX, libc::SOCK_SEQPACKET, 0) {
683                 -1 => return Err(io::Error::last_os_error()),
684                 fd => fd,
685             }
686         };
687 
688         let (addr, len) = sockaddr_un(path.as_ref())?;
689         // Safe connect since we handle the error and use the right length generated from
690         // `sockaddr_un`.
691         unsafe {
692             let ret = handle_eintr_errno!(libc::bind(fd, &addr as *const _ as *const _, len));
693             if ret < 0 {
694                 return Err(io::Error::last_os_error());
695             }
696             let ret = handle_eintr_errno!(libc::listen(fd, 128));
697             if ret < 0 {
698                 return Err(io::Error::last_os_error());
699             }
700         }
701         Ok(UnixSeqpacketListener { fd })
702     }
703 
704     /// Blocks for and accepts a new incoming connection and returns the socket associated with that
705     /// connection.
706     ///
707     /// The returned socket has the close-on-exec flag set.
accept(&self) -> io::Result<UnixSeqpacket>708     pub fn accept(&self) -> io::Result<UnixSeqpacket> {
709         // Safe because we own this fd and the kernel will not write to null pointers.
710         let ret = unsafe { libc::accept4(self.fd, null_mut(), null_mut(), libc::SOCK_CLOEXEC) };
711         if ret < 0 {
712             return Err(io::Error::last_os_error());
713         }
714         // Safe because we checked the return value of accept. Therefore, the return value must be a
715         // valid socket.
716         Ok(unsafe { UnixSeqpacket::from_raw_fd(ret) })
717     }
718 
accept_with_timeout(&self, timeout: Duration) -> io::Result<UnixSeqpacket>719     pub fn accept_with_timeout(&self, timeout: Duration) -> io::Result<UnixSeqpacket> {
720         let start = Instant::now();
721 
722         loop {
723             let mut fds = libc::pollfd {
724                 fd: self.fd,
725                 events: libc::POLLIN,
726                 revents: 0,
727             };
728             let elapsed = Instant::now().saturating_duration_since(start);
729             let remaining = timeout.checked_sub(elapsed).unwrap_or(Duration::ZERO);
730             let cur_timeout_ms = i32::try_from(remaining.as_millis()).unwrap_or(i32::MAX);
731             // Safe because we give a valid pointer to a list (of 1) FD and we check
732             // the return value.
733             match unsafe { libc::poll(&mut fds, 1, cur_timeout_ms) }.cmp(&0) {
734                 Ordering::Greater => return self.accept(),
735                 Ordering::Equal => return Err(io::Error::from_raw_os_error(libc::ETIMEDOUT)),
736                 Ordering::Less => {
737                     if Error::last() != Error::new(libc::EINTR) {
738                         return Err(io::Error::last_os_error());
739                     }
740                 }
741             }
742         }
743     }
744 
745     /// Gets the path that this listener is bound to.
path(&self) -> io::Result<PathBuf>746     pub fn path(&self) -> io::Result<PathBuf> {
747         let mut addr = libc::sockaddr_un {
748             sun_family: libc::AF_UNIX as libc::sa_family_t,
749             sun_path: [0; 108],
750         };
751         let sun_path_offset = (&addr.sun_path as *const _ as usize
752             - &addr.sun_family as *const _ as usize)
753             as libc::socklen_t;
754         let mut len = mem::size_of::<libc::sockaddr_un>() as libc::socklen_t;
755         // Safe because the length given matches the length of the data of the given pointer, and we
756         // check the return value.
757         let ret = unsafe {
758             handle_eintr_errno!(libc::getsockname(
759                 self.fd,
760                 &mut addr as *mut libc::sockaddr_un as *mut libc::sockaddr,
761                 &mut len
762             ))
763         };
764         if ret < 0 {
765             return Err(io::Error::last_os_error());
766         }
767         if addr.sun_family != libc::AF_UNIX as libc::sa_family_t
768             || addr.sun_path[0] == 0
769             || len < 1 + sun_path_offset
770         {
771             return Err(io::Error::new(
772                 io::ErrorKind::InvalidInput,
773                 "getsockname on socket returned invalid value",
774             ));
775         }
776 
777         let path_os_str = OsString::from_vec(
778             addr.sun_path[..(len - sun_path_offset - 1) as usize]
779                 .iter()
780                 .map(|&c| c as _)
781                 .collect(),
782         );
783         Ok(path_os_str.into())
784     }
785 }
786 
787 impl Drop for UnixSeqpacketListener {
drop(&mut self)788     fn drop(&mut self) {
789         // Safe if the UnixSeqpacketListener is created from Self::listen.
790         unsafe {
791             libc::close(self.fd);
792         }
793     }
794 }
795 
796 impl FromRawFd for UnixSeqpacketListener {
797     // Unsafe in drop function
from_raw_fd(fd: RawFd) -> Self798     unsafe fn from_raw_fd(fd: RawFd) -> Self {
799         Self { fd }
800     }
801 }
802 
803 impl AsRawFd for UnixSeqpacketListener {
as_raw_fd(&self) -> RawFd804     fn as_raw_fd(&self) -> RawFd {
805         self.fd
806     }
807 }
808 
809 /// Used to attempt to clean up a `UnixSeqpacketListener` after it is dropped.
810 pub struct UnlinkUnixSeqpacketListener(pub UnixSeqpacketListener);
811 impl AsRef<UnixSeqpacketListener> for UnlinkUnixSeqpacketListener {
as_ref(&self) -> &UnixSeqpacketListener812     fn as_ref(&self) -> &UnixSeqpacketListener {
813         &self.0
814     }
815 }
816 
817 impl AsRawFd for UnlinkUnixSeqpacketListener {
as_raw_fd(&self) -> RawFd818     fn as_raw_fd(&self) -> RawFd {
819         self.0.as_raw_fd()
820     }
821 }
822 
823 impl Deref for UnlinkUnixSeqpacketListener {
824     type Target = UnixSeqpacketListener;
deref(&self) -> &Self::Target825     fn deref(&self) -> &Self::Target {
826         &self.0
827     }
828 }
829 
830 impl Drop for UnlinkUnixSeqpacketListener {
drop(&mut self)831     fn drop(&mut self) {
832         if let Ok(path) = self.0.path() {
833             if let Err(e) = remove_file(path) {
834                 warn!("failed to remove control socket file: {:?}", e);
835             }
836         }
837     }
838 }
839 
#[cfg(test)]
mod tests {
    use super::*;
    use std::{env, io::ErrorKind, path::PathBuf};

    fn tmpdir() -> PathBuf {
        env::temp_dir()
    }

    /// Builds a socket path under the temp directory for the test called `name`.
    ///
    /// The process id is appended so that concurrently running test binaries never share a
    /// socket file. Any file left at that path by an earlier, aborted run is removed first,
    /// because binding to an existing path fails with `EADDRINUSE`.
    fn unique_socket_path(name: &str) -> PathBuf {
        let mut path = tmpdir();
        path.push(format!("{}_{}", name, std::process::id()));
        // Usually there is nothing to remove, so the error is deliberately ignored.
        let _ = remove_file(&path);
        path
    }

    #[test]
    fn sockaddr_un_zero_length_input() {
        let _res = sockaddr_un(Path::new("")).expect("sockaddr_un failed");
    }

    #[test]
    fn sockaddr_un_long_input_err() {
        let res = sockaddr_un(Path::new(&"a".repeat(108)));
        assert!(res.is_err());
    }

    #[test]
    fn sockaddr_un_long_input_pass() {
        let _res = sockaddr_un(Path::new(&"a".repeat(107))).expect("sockaddr_un failed");
    }

    #[test]
    fn sockaddr_un_len_check() {
        let (_addr, len) = sockaddr_un(Path::new(&"a".repeat(50))).expect("sockaddr_un failed");
        assert_eq!(len, (sun_path_offset() + 50 + 1) as u32);
    }

    #[test]
    // `libc::c_char` is `i8` on some targets and `u8` on others, so the cast to it below is a
    // no-op on the latter.
    #[allow(clippy::unnecessary_cast)]
    fn sockaddr_un_pass() {
        let path_size = 50;
        let (addr, len) =
            sockaddr_un(Path::new(&"a".repeat(path_size))).expect("sockaddr_un failed");
        assert_eq!(len, (sun_path_offset() + path_size + 1) as u32);
        assert_eq!(addr.sun_family, libc::AF_UNIX as libc::sa_family_t);

        // `sun_path` must hold the path bytes followed by NUL padding.
        let mut ref_sun_path: [libc::c_char; 108] = [0; 108];
        for c in ref_sun_path.iter_mut().take(path_size) {
            *c = b'a' as libc::c_char;
        }
        assert_eq!(addr.sun_path[..], ref_sun_path[..]);
    }

    #[test]
    fn unix_seqpacket_path_not_exists() {
        let res = UnixSeqpacket::connect("/path/not/exists");
        assert!(res.is_err());
    }

    #[test]
    fn unix_seqpacket_listener_path() {
        let socket_path = unique_socket_path("unix_seqpacket_listener_path");
        let listener = UnlinkUnixSeqpacketListener(
            UnixSeqpacketListener::bind(&socket_path)
                .expect("failed to create UnixSeqpacketListener"),
        );
        let listener_path = listener.path().expect("failed to get socket listener path");
        assert_eq!(socket_path, listener_path);
    }

    #[test]
    fn unlink_unix_seqpacket_listener_removes_socket_file() {
        let socket_path = unique_socket_path("unlink_unix_seqpacket_listener");
        let listener = UnlinkUnixSeqpacketListener(
            UnixSeqpacketListener::bind(&socket_path)
                .expect("failed to create UnixSeqpacketListener"),
        );
        assert!(socket_path.exists(), "bind did not create the socket file");
        drop(listener);
        assert!(!socket_path.exists(), "socket file was not removed on drop");
    }

    #[test]
    fn unix_seqpacket_path_exists_pass() {
        let socket_path = unique_socket_path("path_to_socket");
        let _listener = UnlinkUnixSeqpacketListener(
            UnixSeqpacketListener::bind(&socket_path)
                .expect("failed to create UnixSeqpacketListener"),
        );
        let _res =
            UnixSeqpacket::connect(socket_path.as_path()).expect("UnixSeqpacket::connect failed");
    }

    #[test]
    fn unix_seqpacket_path_listener_accept_with_timeout() {
        let socket_path = unique_socket_path("path_listener_accept_with_timeout");
        let listener = UnlinkUnixSeqpacketListener(
            UnixSeqpacketListener::bind(&socket_path)
                .expect("failed to create UnixSeqpacketListener"),
        );

        for d in [Duration::from_millis(10), Duration::ZERO] {
            // No connection is pending yet, so accepting must give up after the timeout.
            assert!(
                listener.accept_with_timeout(d).is_err(),
                "UnixSeqpacket::accept_with_timeout {:?} connected",
                d
            );

            let s1 = UnixSeqpacket::connect(socket_path.as_path())
                .unwrap_or_else(|_| panic!("UnixSeqpacket::connect {:?} failed", d));

            let s2 = listener
                .accept_with_timeout(d)
                .unwrap_or_else(|_| panic!("UnixSeqpacket::accept {:?} failed", d));

            let data1 = &[0, 1, 2, 3, 4];
            let data2 = &[10, 11, 12, 13, 14];
            s2.send(data2).expect("failed to send data2");
            s1.send(data1).expect("failed to send data1");
            let recv_data = &mut [0; 5];
            s2.recv(recv_data).expect("failed to recv data");
            assert_eq!(data1, recv_data);
            s1.recv(recv_data).expect("failed to recv data");
            assert_eq!(data2, recv_data);
        }
    }

    #[test]
    fn unix_seqpacket_path_listener_accept() {
        let socket_path = unique_socket_path("path_listener_accept");
        let listener = UnlinkUnixSeqpacketListener(
            UnixSeqpacketListener::bind(&socket_path)
                .expect("failed to create UnixSeqpacketListener"),
        );
        let s1 =
            UnixSeqpacket::connect(socket_path.as_path()).expect("UnixSeqpacket::connect failed");

        let s2 = listener.accept().expect("UnixSeqpacket::accept failed");

        let data1 = &[0, 1, 2, 3, 4];
        let data2 = &[10, 11, 12, 13, 14];
        s2.send(data2).expect("failed to send data2");
        s1.send(data1).expect("failed to send data1");
        let recv_data = &mut [0; 5];
        s2.recv(recv_data).expect("failed to recv data");
        assert_eq!(data1, recv_data);
        s1.recv(recv_data).expect("failed to recv data");
        assert_eq!(data2, recv_data);
    }

    #[test]
    fn unix_seqpacket_zero_timeout() {
        let (s1, _s2) = UnixSeqpacket::pair().expect("failed to create socket pair");
        // Timeouts less than a microsecond are too small and round to zero.
        s1.set_read_timeout(Some(Duration::from_nanos(10)))
            .expect_err("successfully set zero timeout");
    }

    #[test]
    fn unix_seqpacket_read_timeout() {
        let (s1, _s2) = UnixSeqpacket::pair().expect("failed to create socket pair");
        s1.set_read_timeout(Some(Duration::from_millis(1)))
            .expect("failed to set read timeout for socket");
        // Nothing was sent, so the receive must fail once the timeout expires.
        assert!(
            s1.recv(&mut [0]).is_err(),
            "recv returned without data despite the read timeout"
        );
    }

    #[test]
    fn unix_seqpacket_write_timeout() {
        let (s1, _s2) = UnixSeqpacket::pair().expect("failed to create socket pair");
        s1.set_write_timeout(Some(Duration::from_millis(1)))
            .expect("failed to set write timeout for socket");
    }

    #[test]
    fn unix_seqpacket_send_recv() {
        let (s1, s2) = UnixSeqpacket::pair().expect("failed to create socket pair");
        let data1 = &[0, 1, 2, 3, 4];
        let data2 = &[10, 11, 12, 13, 14];
        s2.send(data2).expect("failed to send data2");
        s1.send(data1).expect("failed to send data1");
        let recv_data = &mut [0; 5];
        s2.recv(recv_data).expect("failed to recv data");
        assert_eq!(data1, recv_data);
        s1.recv(recv_data).expect("failed to recv data");
        assert_eq!(data2, recv_data);
    }

    #[test]
    fn unix_seqpacket_send_fragments() {
        let (s1, s2) = UnixSeqpacket::pair().expect("failed to create socket pair");
        let data1 = &[0, 1, 2, 3, 4];
        let data2 = &[10, 11, 12, 13, 14, 15, 16];
        s1.send(data1).expect("failed to send data1");
        s1.send(data2).expect("failed to send data2");

        let recv_data = &mut [0; 32];
        let size = s2.recv(recv_data).expect("failed to recv data");
        assert_eq!(size, data1.len());
        assert_eq!(data1, &recv_data[0..size]);

        let size = s2.recv(recv_data).expect("failed to recv data");
        assert_eq!(size, data2.len());
        assert_eq!(data2, &recv_data[0..size]);
    }

    #[test]
    fn unix_seqpacket_get_readable_bytes() {
        let (s1, s2) = UnixSeqpacket::pair().expect("failed to create socket pair");
        assert_eq!(s1.get_readable_bytes().unwrap(), 0);
        assert_eq!(s2.get_readable_bytes().unwrap(), 0);
        let data1 = &[0, 1, 2, 3, 4];
        s1.send(data1).expect("failed to send data");

        assert_eq!(s1.get_readable_bytes().unwrap(), 0);
        assert_eq!(s2.get_readable_bytes().unwrap(), data1.len());

        let recv_data = &mut [0; 5];
        s2.recv(recv_data).expect("failed to recv data");
        assert_eq!(s1.get_readable_bytes().unwrap(), 0);
        assert_eq!(s2.get_readable_bytes().unwrap(), 0);
    }

    #[test]
    fn unix_seqpacket_next_packet_size() {
        let (s1, s2) = UnixSeqpacket::pair().expect("failed to create socket pair");
        let data1 = &[0, 1, 2, 3, 4];
        s1.send(data1).expect("failed to send data");

        assert_eq!(s2.next_packet_size().unwrap(), 5);
        s1.set_read_timeout(Some(Duration::from_micros(1)))
            .expect("failed to set read timeout");
        assert_eq!(
            s1.next_packet_size().unwrap_err().kind(),
            ErrorKind::WouldBlock
        );
        drop(s2);
        assert_eq!(
            s1.next_packet_size().unwrap_err().kind(),
            ErrorKind::ConnectionReset
        );
    }

    #[test]
    fn unix_seqpacket_recv_to_vec() {
        let (s1, s2) = UnixSeqpacket::pair().expect("failed to create socket pair");
        let data1 = &[0, 1, 2, 3, 4];
        s1.send(data1).expect("failed to send data");

        let mut recv_data = Vec::new();
        s2.recv_to_vec(&mut recv_data).expect("failed to recv data");
        assert_eq!(recv_data, vec![0, 1, 2, 3, 4]);
    }

    #[test]
    fn unix_seqpacket_recv_as_vec() {
        let (s1, s2) = UnixSeqpacket::pair().expect("failed to create socket pair");
        let data1 = &[0, 1, 2, 3, 4];
        s1.send(data1).expect("failed to send data");

        let recv_data = s2.recv_as_vec().expect("failed to recv data");
        assert_eq!(recv_data, vec![0, 1, 2, 3, 4]);
    }
}
1092