• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 // Copyright 2018 The Chromium OS Authors. All rights reserved.
2 // Use of this source code is governed by a BSD-style license that can be
3 // found in the LICENSE file.
4 
5 use std::ffi::OsString;
6 use std::fs::remove_file;
7 use std::io;
8 use std::mem::{self, size_of};
9 use std::net::{SocketAddr, SocketAddrV4, SocketAddrV6, TcpListener, TcpStream, ToSocketAddrs};
10 use std::ops::Deref;
11 use std::os::unix::{
12     ffi::{OsStrExt, OsStringExt},
13     io::{AsRawFd, FromRawFd, IntoRawFd, RawFd},
14 };
15 use std::path::Path;
16 use std::path::PathBuf;
17 use std::ptr::null_mut;
18 use std::time::Duration;
19 
20 use libc::{
21     c_int, in6_addr, in_addr, recvfrom, sa_family_t, sockaddr, sockaddr_in, sockaddr_in6,
22     socklen_t, AF_INET, AF_INET6, MSG_PEEK, MSG_TRUNC, SOCK_CLOEXEC, SOCK_STREAM,
23 };
24 use serde::{Deserialize, Serialize};
25 
26 use crate::sock_ctrl_msg::{ScmSocket, SCM_SOCKET_MAX_FD_COUNT};
27 use crate::{AsRawDescriptor, RawDescriptor};
28 
/// Identifies which IP address family (version 4 or version 6) is in use.
#[derive(Debug, Copy, Clone, PartialEq, Eq)]
pub enum InetVersion {
    /// IPv4 (`AF_INET`).
    V4,
    /// IPv6 (`AF_INET6`).
    V6,
}

impl InetVersion {
    /// Reports the IP version of the socket address `s`.
    pub fn from_sockaddr(s: &SocketAddr) -> Self {
        if let SocketAddr::V4(_) = s {
            InetVersion::V4
        } else {
            InetVersion::V6
        }
    }
}
44 
45 impl From<InetVersion> for sa_family_t {
from(v: InetVersion) -> sa_family_t46     fn from(v: InetVersion) -> sa_family_t {
47         match v {
48             InetVersion::V4 => AF_INET as sa_family_t,
49             InetVersion::V6 => AF_INET6 as sa_family_t,
50         }
51     }
52 }
53 
sockaddrv4_to_lib_c(s: &SocketAddrV4) -> sockaddr_in54 fn sockaddrv4_to_lib_c(s: &SocketAddrV4) -> sockaddr_in {
55     sockaddr_in {
56         sin_family: AF_INET as sa_family_t,
57         sin_port: s.port().to_be(),
58         sin_addr: in_addr {
59             s_addr: u32::from_ne_bytes(s.ip().octets()),
60         },
61         sin_zero: [0; 8],
62     }
63 }
64 
sockaddrv6_to_lib_c(s: &SocketAddrV6) -> sockaddr_in665 fn sockaddrv6_to_lib_c(s: &SocketAddrV6) -> sockaddr_in6 {
66     sockaddr_in6 {
67         sin6_family: AF_INET6 as sa_family_t,
68         sin6_port: s.port().to_be(),
69         sin6_flowinfo: 0,
70         sin6_addr: in6_addr {
71             s6_addr: s.ip().octets(),
72         },
73         sin6_scope_id: 0,
74     }
75 }
76 
77 /// A TCP socket.
78 ///
79 /// Do not use this class unless you need to change socket options or query the
80 /// state of the socket prior to calling listen or connect. Instead use either TcpStream or
81 /// TcpListener.
82 #[derive(Debug)]
83 pub struct TcpSocket {
84     inet_version: InetVersion,
85     fd: RawFd,
86 }
87 
88 impl TcpSocket {
new(inet_version: InetVersion) -> io::Result<Self>89     pub fn new(inet_version: InetVersion) -> io::Result<Self> {
90         let fd = unsafe {
91             libc::socket(
92                 Into::<sa_family_t>::into(inet_version) as c_int,
93                 SOCK_STREAM | SOCK_CLOEXEC,
94                 0,
95             )
96         };
97         if fd < 0 {
98             Err(io::Error::last_os_error())
99         } else {
100             Ok(TcpSocket { inet_version, fd })
101         }
102     }
103 
bind<A: ToSocketAddrs>(&mut self, addr: A) -> io::Result<()>104     pub fn bind<A: ToSocketAddrs>(&mut self, addr: A) -> io::Result<()> {
105         let sockaddr = addr
106             .to_socket_addrs()
107             .map_err(|_| io::Error::from_raw_os_error(libc::EINVAL))?
108             .next()
109             .unwrap();
110 
111         let ret = match sockaddr {
112             SocketAddr::V4(a) => {
113                 let sin = sockaddrv4_to_lib_c(&a);
114                 // Safe because this doesn't modify any memory and we check the return value.
115                 unsafe {
116                     libc::bind(
117                         self.fd,
118                         &sin as *const sockaddr_in as *const sockaddr,
119                         size_of::<sockaddr_in>() as socklen_t,
120                     )
121                 }
122             }
123             SocketAddr::V6(a) => {
124                 let sin6 = sockaddrv6_to_lib_c(&a);
125                 // Safe because this doesn't modify any memory and we check the return value.
126                 unsafe {
127                     libc::bind(
128                         self.fd,
129                         &sin6 as *const sockaddr_in6 as *const sockaddr,
130                         size_of::<sockaddr_in6>() as socklen_t,
131                     )
132                 }
133             }
134         };
135         if ret < 0 {
136             let bind_err = io::Error::last_os_error();
137             Err(bind_err)
138         } else {
139             Ok(())
140         }
141     }
142 
connect<A: ToSocketAddrs>(self, addr: A) -> io::Result<TcpStream>143     pub fn connect<A: ToSocketAddrs>(self, addr: A) -> io::Result<TcpStream> {
144         let sockaddr = addr
145             .to_socket_addrs()
146             .map_err(|_| io::Error::from_raw_os_error(libc::EINVAL))?
147             .next()
148             .unwrap();
149 
150         let ret = match sockaddr {
151             SocketAddr::V4(a) => {
152                 let sin = sockaddrv4_to_lib_c(&a);
153                 // Safe because this doesn't modify any memory and we check the return value.
154                 unsafe {
155                     libc::connect(
156                         self.fd,
157                         &sin as *const sockaddr_in as *const sockaddr,
158                         size_of::<sockaddr_in>() as socklen_t,
159                     )
160                 }
161             }
162             SocketAddr::V6(a) => {
163                 let sin6 = sockaddrv6_to_lib_c(&a);
164                 // Safe because this doesn't modify any memory and we check the return value.
165                 unsafe {
166                     libc::connect(
167                         self.fd,
168                         &sin6 as *const sockaddr_in6 as *const sockaddr,
169                         size_of::<sockaddr_in>() as socklen_t,
170                     )
171                 }
172             }
173         };
174 
175         if ret < 0 {
176             let connect_err = io::Error::last_os_error();
177             Err(connect_err)
178         } else {
179             // Safe because the ownership of the raw fd is released from self and taken over by the
180             // new TcpStream.
181             Ok(unsafe { TcpStream::from_raw_fd(self.into_raw_fd()) })
182         }
183     }
184 
listen(self) -> io::Result<TcpListener>185     pub fn listen(self) -> io::Result<TcpListener> {
186         // Safe because this doesn't modify any memory and we check the return value.
187         let ret = unsafe { libc::listen(self.fd, 1) };
188         if ret < 0 {
189             let listen_err = io::Error::last_os_error();
190             Err(listen_err)
191         } else {
192             // Safe because the ownership of the raw fd is released from self and taken over by the
193             // new TcpListener.
194             Ok(unsafe { TcpListener::from_raw_fd(self.into_raw_fd()) })
195         }
196     }
197 
198     /// Returns the port that this socket is bound to. This can only succeed after bind is called.
local_port(&self) -> io::Result<u16>199     pub fn local_port(&self) -> io::Result<u16> {
200         match self.inet_version {
201             InetVersion::V4 => {
202                 let mut sin = sockaddr_in {
203                     sin_family: 0,
204                     sin_port: 0,
205                     sin_addr: in_addr { s_addr: 0 },
206                     sin_zero: [0; 8],
207                 };
208 
209                 // Safe because we give a valid pointer for addrlen and check the length.
210                 let mut addrlen = size_of::<sockaddr_in>() as socklen_t;
211                 let ret = unsafe {
212                     // Get the socket address that was actually bound.
213                     libc::getsockname(
214                         self.fd,
215                         &mut sin as *mut sockaddr_in as *mut sockaddr,
216                         &mut addrlen as *mut socklen_t,
217                     )
218                 };
219                 if ret < 0 {
220                     let getsockname_err = io::Error::last_os_error();
221                     Err(getsockname_err)
222                 } else {
223                     // If this doesn't match, it's not safe to get the port out of the sockaddr.
224                     assert_eq!(addrlen as usize, size_of::<sockaddr_in>());
225 
226                     Ok(u16::from_be(sin.sin_port))
227                 }
228             }
229             InetVersion::V6 => {
230                 let mut sin6 = sockaddr_in6 {
231                     sin6_family: 0,
232                     sin6_port: 0,
233                     sin6_flowinfo: 0,
234                     sin6_addr: in6_addr { s6_addr: [0; 16] },
235                     sin6_scope_id: 0,
236                 };
237 
238                 // Safe because we give a valid pointer for addrlen and check the length.
239                 let mut addrlen = size_of::<sockaddr_in6>() as socklen_t;
240                 let ret = unsafe {
241                     // Get the socket address that was actually bound.
242                     libc::getsockname(
243                         self.fd,
244                         &mut sin6 as *mut sockaddr_in6 as *mut sockaddr,
245                         &mut addrlen as *mut socklen_t,
246                     )
247                 };
248                 if ret < 0 {
249                     let getsockname_err = io::Error::last_os_error();
250                     Err(getsockname_err)
251                 } else {
252                     // If this doesn't match, it's not safe to get the port out of the sockaddr.
253                     assert_eq!(addrlen as usize, size_of::<sockaddr_in>());
254 
255                     Ok(u16::from_be(sin6.sin6_port))
256                 }
257             }
258         }
259     }
260 }
261 
262 impl IntoRawFd for TcpSocket {
into_raw_fd(self) -> RawFd263     fn into_raw_fd(self) -> RawFd {
264         let fd = self.fd;
265         mem::forget(self);
266         fd
267     }
268 }
269 
270 impl AsRawFd for TcpSocket {
as_raw_fd(&self) -> RawFd271     fn as_raw_fd(&self) -> RawFd {
272         self.fd
273     }
274 }
275 
276 impl Drop for TcpSocket {
drop(&mut self)277     fn drop(&mut self) {
278         // Safe because this doesn't modify any memory and we are the only
279         // owner of the file descriptor.
280         unsafe { libc::close(self.fd) };
281     }
282 }
283 
284 // Offset of sun_path in structure sockaddr_un.
sun_path_offset() -> usize285 fn sun_path_offset() -> usize {
286     // Prefer 0 to null() so that we do not need to subtract from the `sub_path` pointer.
287     #[allow(clippy::zero_ptr)]
288     let addr = 0 as *const libc::sockaddr_un;
289     // Safe because we only use the dereference to create a pointer to the desired field in
290     // calculating the offset.
291     unsafe { &(*addr).sun_path as *const _ as usize }
292 }
293 
294 // Return `sockaddr_un` for a given `path`
sockaddr_un<P: AsRef<Path>>(path: P) -> io::Result<(libc::sockaddr_un, libc::socklen_t)>295 fn sockaddr_un<P: AsRef<Path>>(path: P) -> io::Result<(libc::sockaddr_un, libc::socklen_t)> {
296     let mut addr = libc::sockaddr_un {
297         sun_family: libc::AF_UNIX as libc::sa_family_t,
298         sun_path: [0; 108],
299     };
300 
301     // Check if the input path is valid. Since
302     // * The pathname in sun_path should be null-terminated.
303     // * The length of the pathname, including the terminating null byte,
304     //   should not exceed the size of sun_path.
305     //
306     // and our input is a `Path`, we only need to check
307     // * If the string size of `Path` should less than sizeof(sun_path)
308     // and make sure `sun_path` ends with '\0' by initialized the sun_path with zeros.
309     //
310     // Empty path name is valid since abstract socket address has sun_paht[0] = '\0'
311     let bytes = path.as_ref().as_os_str().as_bytes();
312     if bytes.len() >= addr.sun_path.len() {
313         return Err(io::Error::new(
314             io::ErrorKind::InvalidInput,
315             "Input path size should be less than the length of sun_path.",
316         ));
317     };
318 
319     // Copy data from `path` to `addr.sun_path`
320     for (dst, src) in addr.sun_path.iter_mut().zip(bytes) {
321         *dst = *src as libc::c_char;
322     }
323 
324     // The addrlen argument that describes the enclosing sockaddr_un structure
325     // should have a value of at least:
326     //
327     //     offsetof(struct sockaddr_un, sun_path) + strlen(addr.sun_path) + 1
328     //
329     // or, more simply, addrlen can be specified as sizeof(struct sockaddr_un).
330     let len = sun_path_offset() + bytes.len() + 1;
331     Ok((addr, len as libc::socklen_t))
332 }
333 
334 /// A Unix `SOCK_SEQPACKET` socket point to given `path`
335 #[derive(Serialize, Deserialize)]
336 pub struct UnixSeqpacket {
337     #[serde(with = "crate::with_raw_descriptor")]
338     fd: RawFd,
339 }
340 
341 impl UnixSeqpacket {
342     /// Open a `SOCK_SEQPACKET` connection to socket named by `path`.
343     ///
344     /// # Arguments
345     /// * `path` - Path to `SOCK_SEQPACKET` socket
346     ///
347     /// # Returns
348     /// A `UnixSeqpacket` structure point to the socket
349     ///
350     /// # Errors
351     /// Return `io::Error` when error occurs.
connect<P: AsRef<Path>>(path: P) -> io::Result<Self>352     pub fn connect<P: AsRef<Path>>(path: P) -> io::Result<Self> {
353         // Safe socket initialization since we handle the returned error.
354         let fd = unsafe {
355             match libc::socket(libc::AF_UNIX, libc::SOCK_SEQPACKET, 0) {
356                 -1 => return Err(io::Error::last_os_error()),
357                 fd => fd,
358             }
359         };
360 
361         let (addr, len) = sockaddr_un(path.as_ref())?;
362         // Safe connect since we handle the error and use the right length generated from
363         // `sockaddr_un`.
364         unsafe {
365             let ret = libc::connect(fd, &addr as *const _ as *const _, len);
366             if ret < 0 {
367                 return Err(io::Error::last_os_error());
368             }
369         }
370         Ok(UnixSeqpacket { fd })
371     }
372 
373     /// Creates a pair of connected `SOCK_SEQPACKET` sockets.
374     ///
375     /// Both returned file descriptors have the `CLOEXEC` flag set.s
pair() -> io::Result<(UnixSeqpacket, UnixSeqpacket)>376     pub fn pair() -> io::Result<(UnixSeqpacket, UnixSeqpacket)> {
377         let mut fds = [0, 0];
378         unsafe {
379             // Safe because we give enough space to store all the fds and we check the return value.
380             let ret = libc::socketpair(
381                 libc::AF_UNIX,
382                 libc::SOCK_SEQPACKET | libc::SOCK_CLOEXEC,
383                 0,
384                 &mut fds[0],
385             );
386             if ret == 0 {
387                 Ok((
388                     UnixSeqpacket::from_raw_fd(fds[0]),
389                     UnixSeqpacket::from_raw_fd(fds[1]),
390                 ))
391             } else {
392                 Err(io::Error::last_os_error())
393             }
394         }
395     }
396 
397     /// Clone the underlying FD.
try_clone(&self) -> io::Result<Self>398     pub fn try_clone(&self) -> io::Result<Self> {
399         // Safe because this doesn't modify any memory and we check the return value.
400         let fd = unsafe { libc::fcntl(self.fd, libc::F_DUPFD_CLOEXEC, 0) };
401         if fd < 0 {
402             Err(io::Error::last_os_error())
403         } else {
404             Ok(Self { fd })
405         }
406     }
407 
408     /// Gets the number of bytes that can be read from this socket without blocking.
get_readable_bytes(&self) -> io::Result<usize>409     pub fn get_readable_bytes(&self) -> io::Result<usize> {
410         let mut byte_count = 0i32;
411         let ret = unsafe { libc::ioctl(self.fd, libc::FIONREAD, &mut byte_count) };
412         if ret < 0 {
413             Err(io::Error::last_os_error())
414         } else {
415             Ok(byte_count as usize)
416         }
417     }
418 
419     /// Gets the number of bytes in the next packet. This blocks as if `recv` were called,
420     /// respecting the blocking and timeout settings of the underlying socket.
next_packet_size(&self) -> io::Result<usize>421     pub fn next_packet_size(&self) -> io::Result<usize> {
422         #[cfg(not(debug_assertions))]
423         let buf = null_mut();
424         // Work around for qemu's syscall translation which will reject null pointers in recvfrom.
425         // This only matters for running the unit tests for a non-native architecture. See the
426         // upstream thread for the qemu fix:
427         // https://lists.nongnu.org/archive/html/qemu-devel/2021-03/msg09027.html
428         #[cfg(debug_assertions)]
429         let buf = &mut 0 as *mut _ as *mut _;
430 
431         // This form of recvfrom doesn't modify any data because all null pointers are used. We only
432         // use the return value and check for errors on an FD owned by this structure.
433         let ret = unsafe {
434             recvfrom(
435                 self.fd,
436                 buf,
437                 0,
438                 MSG_TRUNC | MSG_PEEK,
439                 null_mut(),
440                 null_mut(),
441             )
442         };
443         if ret < 0 {
444             Err(io::Error::last_os_error())
445         } else {
446             Ok(ret as usize)
447         }
448     }
449 
450     /// Write data from a given buffer to the socket fd
451     ///
452     /// # Arguments
453     /// * `buf` - A reference to the data buffer.
454     ///
455     /// # Returns
456     /// * `usize` - The size of bytes written to the buffer.
457     ///
458     /// # Errors
459     /// Returns error when `libc::write` failed.
send(&self, buf: &[u8]) -> io::Result<usize>460     pub fn send(&self, buf: &[u8]) -> io::Result<usize> {
461         // Safe since we make sure the input `count` == `buf.len()` and handle the returned error.
462         unsafe {
463             let ret = libc::write(self.fd, buf.as_ptr() as *const _, buf.len());
464             if ret < 0 {
465                 Err(io::Error::last_os_error())
466             } else {
467                 Ok(ret as usize)
468             }
469         }
470     }
471 
472     /// Read data from the socket fd to a given buffer
473     ///
474     /// # Arguments
475     /// * `buf` - A mut reference to the data buffer.
476     ///
477     /// # Returns
478     /// * `usize` - The size of bytes read to the buffer.
479     ///
480     /// # Errors
481     /// Returns error when `libc::read` failed.
recv(&self, buf: &mut [u8]) -> io::Result<usize>482     pub fn recv(&self, buf: &mut [u8]) -> io::Result<usize> {
483         // Safe since we make sure the input `count` == `buf.len()` and handle the returned error.
484         unsafe {
485             let ret = libc::read(self.fd, buf.as_mut_ptr() as *mut _, buf.len());
486             if ret < 0 {
487                 Err(io::Error::last_os_error())
488             } else {
489                 Ok(ret as usize)
490             }
491         }
492     }
493 
494     /// Read data from the socket fd to a given `Vec`, resizing it to the received packet's size.
495     ///
496     /// # Arguments
497     /// * `buf` - A mut reference to a `Vec` to resize and read into.
498     ///
499     /// # Errors
500     /// Returns error when `libc::read` or `get_readable_bytes` failed.
recv_to_vec(&self, buf: &mut Vec<u8>) -> io::Result<()>501     pub fn recv_to_vec(&self, buf: &mut Vec<u8>) -> io::Result<()> {
502         let packet_size = self.next_packet_size()?;
503         buf.resize(packet_size, 0);
504         let read_bytes = self.recv(buf)?;
505         buf.resize(read_bytes, 0);
506         Ok(())
507     }
508 
509     /// Read data from the socket fd to a new `Vec`.
510     ///
511     /// # Returns
512     /// * `vec` - A new `Vec` with the entire received packet.
513     ///
514     /// # Errors
515     /// Returns error when `libc::read` or `get_readable_bytes` failed.
recv_as_vec(&self) -> io::Result<Vec<u8>>516     pub fn recv_as_vec(&self) -> io::Result<Vec<u8>> {
517         let mut buf = Vec::new();
518         self.recv_to_vec(&mut buf)?;
519         Ok(buf)
520     }
521 
522     /// Read data and fds from the socket fd to a new pair of `Vec`.
523     ///
524     /// # Returns
525     /// * `Vec<u8>` - A new `Vec` with the entire received packet's bytes.
526     /// * `Vec<RawFd>` - A new `Vec` with the entire received packet's fds.
527     ///
528     /// # Errors
529     /// Returns error when `recv_with_fds` or `get_readable_bytes` failed.
recv_as_vec_with_fds(&self) -> io::Result<(Vec<u8>, Vec<RawFd>)>530     pub fn recv_as_vec_with_fds(&self) -> io::Result<(Vec<u8>, Vec<RawFd>)> {
531         let packet_size = self.next_packet_size()?;
532         let mut buf = vec![0; packet_size];
533         let mut fd_buf = vec![-1; SCM_SOCKET_MAX_FD_COUNT];
534         let (read_bytes, read_fds) = self.recv_with_fds(&mut buf, &mut fd_buf)?;
535         buf.resize(read_bytes, 0);
536         fd_buf.resize(read_fds, -1);
537         Ok((buf, fd_buf))
538     }
539 
set_timeout(&self, timeout: Option<Duration>, kind: libc::c_int) -> io::Result<()>540     fn set_timeout(&self, timeout: Option<Duration>, kind: libc::c_int) -> io::Result<()> {
541         let timeval = match timeout {
542             Some(t) => {
543                 if t.as_secs() == 0 && t.subsec_micros() == 0 {
544                     return Err(io::Error::new(
545                         io::ErrorKind::InvalidInput,
546                         "zero timeout duration is invalid",
547                     ));
548                 }
549                 // subsec_micros fits in i32 because it is defined to be less than one million.
550                 let nsec = t.subsec_micros() as i32;
551                 libc::timeval {
552                     tv_sec: t.as_secs() as libc::time_t,
553                     tv_usec: libc::suseconds_t::from(nsec),
554                 }
555             }
556             None => libc::timeval {
557                 tv_sec: 0,
558                 tv_usec: 0,
559             },
560         };
561         // Safe because we own the fd, and the length of the pointer's data is the same as the
562         // passed in length parameter. The level argument is valid, the kind is assumed to be valid,
563         // and the return value is checked.
564         let ret = unsafe {
565             libc::setsockopt(
566                 self.fd,
567                 libc::SOL_SOCKET,
568                 kind,
569                 &timeval as *const libc::timeval as *const libc::c_void,
570                 mem::size_of::<libc::timeval>() as libc::socklen_t,
571             )
572         };
573         if ret < 0 {
574             Err(io::Error::last_os_error())
575         } else {
576             Ok(())
577         }
578     }
579 
580     /// Sets or removes the timeout for read/recv operations on this socket.
set_read_timeout(&self, timeout: Option<Duration>) -> io::Result<()>581     pub fn set_read_timeout(&self, timeout: Option<Duration>) -> io::Result<()> {
582         self.set_timeout(timeout, libc::SO_RCVTIMEO)
583     }
584 
585     /// Sets or removes the timeout for write/send operations on this socket.
set_write_timeout(&self, timeout: Option<Duration>) -> io::Result<()>586     pub fn set_write_timeout(&self, timeout: Option<Duration>) -> io::Result<()> {
587         self.set_timeout(timeout, libc::SO_SNDTIMEO)
588     }
589 }
590 
591 impl Drop for UnixSeqpacket {
drop(&mut self)592     fn drop(&mut self) {
593         // Safe if the UnixSeqpacket is created from Self::connect.
594         unsafe {
595             libc::close(self.fd);
596         }
597     }
598 }
599 
600 impl FromRawFd for UnixSeqpacket {
601     // Unsafe in drop function
from_raw_fd(fd: RawFd) -> Self602     unsafe fn from_raw_fd(fd: RawFd) -> Self {
603         Self { fd }
604     }
605 }
606 
607 impl AsRawFd for UnixSeqpacket {
as_raw_fd(&self) -> RawFd608     fn as_raw_fd(&self) -> RawFd {
609         self.fd
610     }
611 }
612 
613 impl AsRawFd for &UnixSeqpacket {
as_raw_fd(&self) -> RawFd614     fn as_raw_fd(&self) -> RawFd {
615         self.fd
616     }
617 }
618 
619 impl AsRawDescriptor for UnixSeqpacket {
as_raw_descriptor(&self) -> RawDescriptor620     fn as_raw_descriptor(&self) -> RawDescriptor {
621         self.fd
622     }
623 }
624 
625 /// Like a `UnixListener` but for accepting `UnixSeqpacket` type sockets.
626 pub struct UnixSeqpacketListener {
627     fd: RawFd,
628 }
629 
630 impl UnixSeqpacketListener {
631     /// Creates a new `UnixSeqpacketListener` bound to the given path.
bind<P: AsRef<Path>>(path: P) -> io::Result<Self>632     pub fn bind<P: AsRef<Path>>(path: P) -> io::Result<Self> {
633         // Safe socket initialization since we handle the returned error.
634         let fd = unsafe {
635             match libc::socket(libc::AF_UNIX, libc::SOCK_SEQPACKET, 0) {
636                 -1 => return Err(io::Error::last_os_error()),
637                 fd => fd,
638             }
639         };
640 
641         let (addr, len) = sockaddr_un(path.as_ref())?;
642         // Safe connect since we handle the error and use the right length generated from
643         // `sockaddr_un`.
644         unsafe {
645             let ret = handle_eintr_errno!(libc::bind(fd, &addr as *const _ as *const _, len));
646             if ret < 0 {
647                 return Err(io::Error::last_os_error());
648             }
649             let ret = handle_eintr_errno!(libc::listen(fd, 128));
650             if ret < 0 {
651                 return Err(io::Error::last_os_error());
652             }
653         }
654         Ok(UnixSeqpacketListener { fd })
655     }
656 
657     /// Blocks for and accepts a new incoming connection and returns the socket associated with that
658     /// connection.
659     ///
660     /// The returned socket has the close-on-exec flag set.
accept(&self) -> io::Result<UnixSeqpacket>661     pub fn accept(&self) -> io::Result<UnixSeqpacket> {
662         // Safe because we own this fd and the kernel will not write to null pointers.
663         let ret = unsafe { libc::accept4(self.fd, null_mut(), null_mut(), libc::SOCK_CLOEXEC) };
664         if ret < 0 {
665             return Err(io::Error::last_os_error());
666         }
667         // Safe because we checked the return value of accept. Therefore, the return value must be a
668         // valid socket.
669         Ok(unsafe { UnixSeqpacket::from_raw_fd(ret) })
670     }
671 
672     /// Gets the path that this listener is bound to.
path(&self) -> io::Result<PathBuf>673     pub fn path(&self) -> io::Result<PathBuf> {
674         let mut addr = libc::sockaddr_un {
675             sun_family: libc::AF_UNIX as libc::sa_family_t,
676             sun_path: [0; 108],
677         };
678         let sun_path_offset = (&addr.sun_path as *const _ as usize
679             - &addr.sun_family as *const _ as usize)
680             as libc::socklen_t;
681         let mut len = mem::size_of::<libc::sockaddr_un>() as libc::socklen_t;
682         // Safe because the length given matches the length of the data of the given pointer, and we
683         // check the return value.
684         let ret = unsafe {
685             handle_eintr_errno!(libc::getsockname(
686                 self.fd,
687                 &mut addr as *mut libc::sockaddr_un as *mut libc::sockaddr,
688                 &mut len
689             ))
690         };
691         if ret < 0 {
692             return Err(io::Error::last_os_error());
693         }
694         if addr.sun_family != libc::AF_UNIX as libc::sa_family_t
695             || addr.sun_path[0] == 0
696             || len < 1 + sun_path_offset
697         {
698             return Err(io::Error::new(
699                 io::ErrorKind::InvalidInput,
700                 "getsockname on socket returned invalid value",
701             ));
702         }
703 
704         let path_os_str = OsString::from_vec(
705             addr.sun_path[..(len - sun_path_offset - 1) as usize]
706                 .iter()
707                 .map(|&c| c as _)
708                 .collect(),
709         );
710         Ok(path_os_str.into())
711     }
712 }
713 
714 impl Drop for UnixSeqpacketListener {
drop(&mut self)715     fn drop(&mut self) {
716         // Safe if the UnixSeqpacketListener is created from Self::listen.
717         unsafe {
718             libc::close(self.fd);
719         }
720     }
721 }
722 
723 impl FromRawFd for UnixSeqpacketListener {
724     // Unsafe in drop function
from_raw_fd(fd: RawFd) -> Self725     unsafe fn from_raw_fd(fd: RawFd) -> Self {
726         Self { fd }
727     }
728 }
729 
730 impl AsRawFd for UnixSeqpacketListener {
as_raw_fd(&self) -> RawFd731     fn as_raw_fd(&self) -> RawFd {
732         self.fd
733     }
734 }
735 
736 /// Used to attempt to clean up a `UnixSeqpacketListener` after it is dropped.
737 pub struct UnlinkUnixSeqpacketListener(pub UnixSeqpacketListener);
738 impl AsRef<UnixSeqpacketListener> for UnlinkUnixSeqpacketListener {
as_ref(&self) -> &UnixSeqpacketListener739     fn as_ref(&self) -> &UnixSeqpacketListener {
740         &self.0
741     }
742 }
743 
744 impl AsRawFd for UnlinkUnixSeqpacketListener {
as_raw_fd(&self) -> RawFd745     fn as_raw_fd(&self) -> RawFd {
746         self.0.as_raw_fd()
747     }
748 }
749 
750 impl Deref for UnlinkUnixSeqpacketListener {
751     type Target = UnixSeqpacketListener;
deref(&self) -> &Self::Target752     fn deref(&self) -> &Self::Target {
753         &self.0
754     }
755 }
756 
757 impl Drop for UnlinkUnixSeqpacketListener {
drop(&mut self)758     fn drop(&mut self) {
759         if let Ok(path) = self.0.path() {
760             if let Err(e) = remove_file(path) {
761                 warn!("failed to remove control socket file: {:?}", e);
762             }
763         }
764     }
765 }
766 
#[cfg(test)]
mod tests {
    use super::*;
    use std::env;
    use std::io::ErrorKind;
    use std::path::PathBuf;

    fn tmpdir() -> PathBuf {
        env::temp_dir()
    }

    /// Returns a socket path in the temp dir that is unique to this test process.
    ///
    /// The process id is appended so that concurrent runs of the test binary (or runs by
    /// other users sharing the temp dir) do not collide on a fixed file name. Any stale
    /// file with the same name, e.g. one left by a run that was killed before its listener
    /// was dropped, is removed first so `bind` does not fail because the address is in use.
    fn unique_socket_path(name: &str) -> PathBuf {
        let mut path = tmpdir();
        path.push(format!("{}_{}", name, std::process::id()));
        // Deliberately ignored: normally the file does not exist and removal fails.
        let _ = std::fs::remove_file(&path);
        path
    }

    /// An empty path is accepted.
    #[test]
    fn sockaddr_un_zero_length_input() {
        let _res = sockaddr_un(Path::new("")).expect("sockaddr_un failed");
    }

    /// A 108-byte path is rejected; compare with the 107-byte case below, which passes.
    #[test]
    fn sockaddr_un_long_input_err() {
        let res = sockaddr_un(Path::new(&"a".repeat(108)));
        assert!(res.is_err());
    }

    /// A 107-byte path is the longest one accepted.
    #[test]
    fn sockaddr_un_long_input_pass() {
        let _res = sockaddr_un(Path::new(&"a".repeat(107))).expect("sockaddr_un failed");
    }

    /// The returned length covers everything before `sun_path`, the path bytes, and one
    /// trailing NUL.
    #[test]
    fn sockaddr_un_len_check() {
        let (_addr, len) = sockaddr_un(Path::new(&"a".repeat(50))).expect("sockaddr_un failed");
        assert_eq!(len, (sun_path_offset() + 50 + 1) as u32);
    }

    /// The returned address has the AF_UNIX family, and `sun_path` holds the path bytes
    /// followed by NUL padding.
    #[test]
    fn sockaddr_un_pass() {
        let path_size = 50;
        let (addr, len) =
            sockaddr_un(Path::new(&"a".repeat(path_size))).expect("sockaddr_un failed");
        assert_eq!(len, (sun_path_offset() + path_size + 1) as u32);
        assert_eq!(addr.sun_family, libc::AF_UNIX as libc::sa_family_t);

        // Size the expectation from `sun_path` itself, so the whole array is compared
        // rather than whatever prefix a `zip` with a hard-coded length would cover.
        let expected_sun_path: Vec<libc::c_char> = (0..addr.sun_path.len())
            .map(|i| if i < path_size { 'a' as libc::c_char } else { 0 })
            .collect();
        assert_eq!(&addr.sun_path[..], &expected_sun_path[..]);
    }

    /// Connecting to a nonexistent socket path fails.
    #[test]
    fn unix_seqpacket_path_not_exists() {
        let res = UnixSeqpacket::connect("/path/not/exists");
        assert!(res.is_err());
    }

    /// A bound listener reports the same path it was bound to.
    #[test]
    fn unix_seqpacket_listener_path() {
        let socket_path = unique_socket_path("unix_seqpacket_listener_path");
        let listener = UnlinkUnixSeqpacketListener(
            UnixSeqpacketListener::bind(&socket_path)
                .expect("failed to create UnixSeqpacketListener"),
        );
        let listener_path = listener.path().expect("failed to get socket listener path");
        assert_eq!(socket_path, listener_path);
    }

    /// Connecting to a path with a bound listener succeeds.
    #[test]
    fn unix_seqpacket_path_exists_pass() {
        let socket_path = unique_socket_path("path_to_socket");
        let _listener = UnlinkUnixSeqpacketListener(
            UnixSeqpacketListener::bind(&socket_path)
                .expect("failed to create UnixSeqpacketListener"),
        );
        let _res =
            UnixSeqpacket::connect(socket_path.as_path()).expect("UnixSeqpacket::connect failed");
    }

    /// A connection accepted by the listener can exchange packets in both directions.
    #[test]
    fn unix_seqpacket_path_listener_accept() {
        let socket_path = unique_socket_path("path_listener_accept");
        let listener = UnlinkUnixSeqpacketListener(
            UnixSeqpacketListener::bind(&socket_path)
                .expect("failed to create UnixSeqpacketListener"),
        );
        let s1 =
            UnixSeqpacket::connect(socket_path.as_path()).expect("UnixSeqpacket::connect failed");

        let s2 = listener.accept().expect("UnixSeqpacket::accept failed");

        let data1 = &[0, 1, 2, 3, 4];
        let data2 = &[10, 11, 12, 13, 14];
        s2.send(data2).expect("failed to send data2");
        s1.send(data1).expect("failed to send data1");
        let recv_data = &mut [0; 5];
        s2.recv(recv_data).expect("failed to recv data");
        assert_eq!(data1, recv_data);
        s1.recv(recv_data).expect("failed to recv data");
        assert_eq!(data2, recv_data);
    }

    /// Timeouts shorter than a microsecond round to zero and are rejected.
    #[test]
    fn unix_seqpacket_zero_timeout() {
        let (s1, _s2) = UnixSeqpacket::pair().expect("failed to create socket pair");
        s1.set_read_timeout(Some(Duration::from_nanos(10)))
            .expect_err("successfully set zero timeout");
    }

    /// With a read timeout set and no data pending, `recv` returns an error instead of
    /// blocking forever.
    #[test]
    fn unix_seqpacket_read_timeout() {
        let (s1, _s2) = UnixSeqpacket::pair().expect("failed to create socket pair");
        s1.set_read_timeout(Some(Duration::from_millis(1)))
            .expect("failed to set read timeout for socket");
        assert!(
            s1.recv(&mut [0]).is_err(),
            "recv succeeded even though no data was sent"
        );
    }

    /// A non-zero write timeout can be set.
    #[test]
    fn unix_seqpacket_write_timeout() {
        let (s1, _s2) = UnixSeqpacket::pair().expect("failed to create socket pair");
        s1.set_write_timeout(Some(Duration::from_millis(1)))
            .expect("failed to set write timeout for socket");
    }

    /// Both ends of a socket pair can send to and receive from each other.
    #[test]
    fn unix_seqpacket_send_recv() {
        let (s1, s2) = UnixSeqpacket::pair().expect("failed to create socket pair");
        let data1 = &[0, 1, 2, 3, 4];
        let data2 = &[10, 11, 12, 13, 14];
        s2.send(data2).expect("failed to send data2");
        s1.send(data1).expect("failed to send data1");
        let recv_data = &mut [0; 5];
        s2.recv(recv_data).expect("failed to recv data");
        assert_eq!(data1, recv_data);
        s1.recv(recv_data).expect("failed to recv data");
        assert_eq!(data2, recv_data);
    }

    /// Packet boundaries are preserved: each `recv` returns exactly one sent packet, even
    /// when the receive buffer could hold both.
    #[test]
    fn unix_seqpacket_send_fragments() {
        let (s1, s2) = UnixSeqpacket::pair().expect("failed to create socket pair");
        let data1 = &[0, 1, 2, 3, 4];
        let data2 = &[10, 11, 12, 13, 14, 15, 16];
        s1.send(data1).expect("failed to send data1");
        s1.send(data2).expect("failed to send data2");

        let recv_data = &mut [0; 32];
        let size = s2.recv(recv_data).expect("failed to recv data");
        assert_eq!(size, data1.len());
        assert_eq!(data1, &recv_data[0..size]);

        let size = s2.recv(recv_data).expect("failed to recv data");
        assert_eq!(size, data2.len());
        assert_eq!(data2, &recv_data[0..size]);
    }

    /// Readable bytes are reported only on the receiving end, and drop back to zero once
    /// the packet is consumed.
    #[test]
    fn unix_seqpacket_get_readable_bytes() {
        let (s1, s2) = UnixSeqpacket::pair().expect("failed to create socket pair");
        assert_eq!(s1.get_readable_bytes().unwrap(), 0);
        assert_eq!(s2.get_readable_bytes().unwrap(), 0);
        let data1 = &[0, 1, 2, 3, 4];
        s1.send(data1).expect("failed to send data");

        assert_eq!(s1.get_readable_bytes().unwrap(), 0);
        assert_eq!(s2.get_readable_bytes().unwrap(), data1.len());

        let recv_data = &mut [0; 5];
        s2.recv(recv_data).expect("failed to recv data");
        assert_eq!(s1.get_readable_bytes().unwrap(), 0);
        assert_eq!(s2.get_readable_bytes().unwrap(), 0);
    }

    /// `next_packet_size` reports a pending packet's size. It yields `WouldBlock` when
    /// nothing arrives before the read timeout, and `ConnectionReset` once the peer is
    /// dropped.
    #[test]
    fn unix_seqpacket_next_packet_size() {
        let (s1, s2) = UnixSeqpacket::pair().expect("failed to create socket pair");
        let data1 = &[0, 1, 2, 3, 4];
        s1.send(data1).expect("failed to send data");

        assert_eq!(s2.next_packet_size().unwrap(), 5);
        s1.set_read_timeout(Some(Duration::from_micros(1)))
            .expect("failed to set read timeout");
        assert_eq!(
            s1.next_packet_size().unwrap_err().kind(),
            ErrorKind::WouldBlock
        );
        drop(s2);
        assert_eq!(
            s1.next_packet_size().unwrap_err().kind(),
            ErrorKind::ConnectionReset
        );
    }

    /// `recv_to_vec` fills the caller's vector with the received packet.
    #[test]
    fn unix_seqpacket_recv_to_vec() {
        let (s1, s2) = UnixSeqpacket::pair().expect("failed to create socket pair");
        let data1 = &[0, 1, 2, 3, 4];
        s1.send(data1).expect("failed to send data");

        let mut recv_data = Vec::new();
        s2.recv_to_vec(&mut recv_data).expect("failed to recv data");
        assert_eq!(recv_data, vec![0, 1, 2, 3, 4]);
    }

    /// `recv_as_vec` returns the received packet as a new vector.
    #[test]
    fn unix_seqpacket_recv_as_vec() {
        let (s1, s2) = UnixSeqpacket::pair().expect("failed to create socket pair");
        let data1 = &[0, 1, 2, 3, 4];
        s1.send(data1).expect("failed to send data");

        let recv_data = s2.recv_as_vec().expect("failed to recv data");
        assert_eq!(recv_data, vec![0, 1, 2, 3, 4]);
    }
}
986