1 // Copyright 2018 The ChromiumOS Authors
2 // Use of this source code is governed by a BSD-style license that can be
3 // found in the LICENSE file.
4
5 use std::cmp::Ordering;
6 use std::convert::TryFrom;
7 use std::ffi::OsString;
8 use std::fs::remove_file;
9 use std::io;
10 use std::mem;
11 use std::mem::size_of;
12 use std::net::SocketAddr;
13 use std::net::SocketAddrV4;
14 use std::net::SocketAddrV6;
15 use std::net::TcpListener;
16 use std::net::TcpStream;
17 use std::net::ToSocketAddrs;
18 use std::ops::Deref;
19 use std::os::unix::ffi::OsStrExt;
20 use std::os::unix::ffi::OsStringExt;
21 use std::os::unix::io::RawFd;
22 use std::path::Path;
23 use std::path::PathBuf;
24 use std::ptr::null_mut;
25 use std::time::Duration;
26 use std::time::Instant;
27
28 use libc::c_int;
29 use libc::in6_addr;
30 use libc::in_addr;
31 use libc::recvfrom;
32 use libc::sa_family_t;
33 use libc::sockaddr;
34 use libc::sockaddr_in;
35 use libc::sockaddr_in6;
36 use libc::socklen_t;
37 use libc::AF_INET;
38 use libc::AF_INET6;
39 use libc::MSG_PEEK;
40 use libc::MSG_TRUNC;
41 use libc::SOCK_CLOEXEC;
42 use libc::SOCK_STREAM;
43 use log::warn;
44 use serde::Deserialize;
45 use serde::Serialize;
46
47 use super::sock_ctrl_msg::ScmSocket;
48 use super::sock_ctrl_msg::SCM_SOCKET_MAX_FD_COUNT;
49 use super::Error;
50 use super::RawDescriptor;
51 use crate::descriptor::AsRawDescriptor;
52 use crate::descriptor::FromRawDescriptor;
53 use crate::SafeDescriptor;
54
/// Distinguishes the two IP versions so that code can handle both IPv4 and IPv6 sockets.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum InetVersion {
    /// IP version 4 (maps to `AF_INET`).
    V4,
    /// IP version 6 (maps to `AF_INET6`).
    V6,
}

impl InetVersion {
    /// Returns the IP version that the given socket address belongs to.
    pub fn from_sockaddr(s: &SocketAddr) -> Self {
        if s.is_ipv4() {
            Self::V4
        } else {
            Self::V6
        }
    }
}
70
71 impl From<InetVersion> for sa_family_t {
from(v: InetVersion) -> sa_family_t72 fn from(v: InetVersion) -> sa_family_t {
73 match v {
74 InetVersion::V4 => AF_INET as sa_family_t,
75 InetVersion::V6 => AF_INET6 as sa_family_t,
76 }
77 }
78 }
79
sockaddrv4_to_lib_c(s: &SocketAddrV4) -> sockaddr_in80 fn sockaddrv4_to_lib_c(s: &SocketAddrV4) -> sockaddr_in {
81 sockaddr_in {
82 sin_family: AF_INET as sa_family_t,
83 sin_port: s.port().to_be(),
84 sin_addr: in_addr {
85 s_addr: u32::from_ne_bytes(s.ip().octets()),
86 },
87 sin_zero: [0; 8],
88 }
89 }
90
sockaddrv6_to_lib_c(s: &SocketAddrV6) -> sockaddr_in691 fn sockaddrv6_to_lib_c(s: &SocketAddrV6) -> sockaddr_in6 {
92 sockaddr_in6 {
93 sin6_family: AF_INET6 as sa_family_t,
94 sin6_port: s.port().to_be(),
95 sin6_flowinfo: 0,
96 sin6_addr: in6_addr {
97 s6_addr: s.ip().octets(),
98 },
99 sin6_scope_id: 0,
100 }
101 }
102
socket(domain: c_int, sock_type: c_int, protocol: c_int) -> io::Result<SafeDescriptor>103 fn socket(domain: c_int, sock_type: c_int, protocol: c_int) -> io::Result<SafeDescriptor> {
104 // Safe socket initialization since we handle the returned error.
105 match unsafe { libc::socket(domain, sock_type, protocol) } {
106 -1 => Err(io::Error::last_os_error()),
107 // Safe because we own the file descriptor.
108 fd => Ok(unsafe { SafeDescriptor::from_raw_descriptor(fd) }),
109 }
110 }
111
socketpair( domain: c_int, sock_type: c_int, protocol: c_int, ) -> io::Result<(SafeDescriptor, SafeDescriptor)>112 fn socketpair(
113 domain: c_int,
114 sock_type: c_int,
115 protocol: c_int,
116 ) -> io::Result<(SafeDescriptor, SafeDescriptor)> {
117 let mut fds = [0, 0];
118 // Safe because we give enough space to store all the fds and we check the return value.
119 match unsafe { libc::socketpair(domain, sock_type, protocol, fds.as_mut_ptr()) } {
120 -1 => Err(io::Error::last_os_error()),
121 // Safe because we own the file descriptors.
122 _ => Ok(unsafe {
123 (
124 SafeDescriptor::from_raw_descriptor(fds[0]),
125 SafeDescriptor::from_raw_descriptor(fds[1]),
126 )
127 }),
128 }
129 }
130
/// A TCP socket.
///
/// Do not use this type unless you need to change socket options or query the state of the
/// socket before calling `listen` or `connect`. Otherwise, use `TcpStream` or `TcpListener`
/// directly.
#[derive(Debug)]
pub struct TcpSocket {
    // IP version the socket was created for; `local_port` uses it to pick the address layout.
    inet_version: InetVersion,
    // Socket descriptor created in `TcpSocket::new`.
    descriptor: SafeDescriptor,
}
141
142 impl TcpSocket {
new(inet_version: InetVersion) -> io::Result<Self>143 pub fn new(inet_version: InetVersion) -> io::Result<Self> {
144 Ok(TcpSocket {
145 inet_version,
146 descriptor: socket(
147 Into::<sa_family_t>::into(inet_version) as c_int,
148 SOCK_STREAM | SOCK_CLOEXEC,
149 0,
150 )?,
151 })
152 }
153
bind<A: ToSocketAddrs>(&mut self, addr: A) -> io::Result<()>154 pub fn bind<A: ToSocketAddrs>(&mut self, addr: A) -> io::Result<()> {
155 let sockaddr = addr
156 .to_socket_addrs()
157 .map_err(|_| io::Error::from_raw_os_error(libc::EINVAL))?
158 .next()
159 .unwrap();
160
161 let ret = match sockaddr {
162 SocketAddr::V4(a) => {
163 let sin = sockaddrv4_to_lib_c(&a);
164 // Safe because this doesn't modify any memory and we check the return value.
165 unsafe {
166 libc::bind(
167 self.as_raw_descriptor(),
168 &sin as *const sockaddr_in as *const sockaddr,
169 size_of::<sockaddr_in>() as socklen_t,
170 )
171 }
172 }
173 SocketAddr::V6(a) => {
174 let sin6 = sockaddrv6_to_lib_c(&a);
175 // Safe because this doesn't modify any memory and we check the return value.
176 unsafe {
177 libc::bind(
178 self.as_raw_descriptor(),
179 &sin6 as *const sockaddr_in6 as *const sockaddr,
180 size_of::<sockaddr_in6>() as socklen_t,
181 )
182 }
183 }
184 };
185 if ret < 0 {
186 let bind_err = io::Error::last_os_error();
187 Err(bind_err)
188 } else {
189 Ok(())
190 }
191 }
192
connect<A: ToSocketAddrs>(self, addr: A) -> io::Result<TcpStream>193 pub fn connect<A: ToSocketAddrs>(self, addr: A) -> io::Result<TcpStream> {
194 let sockaddr = addr
195 .to_socket_addrs()
196 .map_err(|_| io::Error::from_raw_os_error(libc::EINVAL))?
197 .next()
198 .unwrap();
199
200 let ret = match sockaddr {
201 SocketAddr::V4(a) => {
202 let sin = sockaddrv4_to_lib_c(&a);
203 // Safe because this doesn't modify any memory and we check the return value.
204 unsafe {
205 libc::connect(
206 self.as_raw_descriptor(),
207 &sin as *const sockaddr_in as *const sockaddr,
208 size_of::<sockaddr_in>() as socklen_t,
209 )
210 }
211 }
212 SocketAddr::V6(a) => {
213 let sin6 = sockaddrv6_to_lib_c(&a);
214 // Safe because this doesn't modify any memory and we check the return value.
215 unsafe {
216 libc::connect(
217 self.as_raw_descriptor(),
218 &sin6 as *const sockaddr_in6 as *const sockaddr,
219 size_of::<sockaddr_in>() as socklen_t,
220 )
221 }
222 }
223 };
224
225 if ret < 0 {
226 let connect_err = io::Error::last_os_error();
227 Err(connect_err)
228 } else {
229 Ok(TcpStream::from(self.descriptor))
230 }
231 }
232
listen(self) -> io::Result<TcpListener>233 pub fn listen(self) -> io::Result<TcpListener> {
234 // Safe because this doesn't modify any memory and we check the return value.
235 let ret = unsafe { libc::listen(self.as_raw_descriptor(), 1) };
236 if ret < 0 {
237 let listen_err = io::Error::last_os_error();
238 Err(listen_err)
239 } else {
240 Ok(TcpListener::from(self.descriptor))
241 }
242 }
243
244 /// Returns the port that this socket is bound to. This can only succeed after bind is called.
local_port(&self) -> io::Result<u16>245 pub fn local_port(&self) -> io::Result<u16> {
246 match self.inet_version {
247 InetVersion::V4 => {
248 let mut sin = sockaddr_in {
249 sin_family: 0,
250 sin_port: 0,
251 sin_addr: in_addr { s_addr: 0 },
252 sin_zero: [0; 8],
253 };
254
255 // Safe because we give a valid pointer for addrlen and check the length.
256 let mut addrlen = size_of::<sockaddr_in>() as socklen_t;
257 let ret = unsafe {
258 // Get the socket address that was actually bound.
259 libc::getsockname(
260 self.as_raw_descriptor(),
261 &mut sin as *mut sockaddr_in as *mut sockaddr,
262 &mut addrlen as *mut socklen_t,
263 )
264 };
265 if ret < 0 {
266 let getsockname_err = io::Error::last_os_error();
267 Err(getsockname_err)
268 } else {
269 // If this doesn't match, it's not safe to get the port out of the sockaddr.
270 assert_eq!(addrlen as usize, size_of::<sockaddr_in>());
271
272 Ok(u16::from_be(sin.sin_port))
273 }
274 }
275 InetVersion::V6 => {
276 let mut sin6 = sockaddr_in6 {
277 sin6_family: 0,
278 sin6_port: 0,
279 sin6_flowinfo: 0,
280 sin6_addr: in6_addr { s6_addr: [0; 16] },
281 sin6_scope_id: 0,
282 };
283
284 // Safe because we give a valid pointer for addrlen and check the length.
285 let mut addrlen = size_of::<sockaddr_in6>() as socklen_t;
286 let ret = unsafe {
287 // Get the socket address that was actually bound.
288 libc::getsockname(
289 self.as_raw_descriptor(),
290 &mut sin6 as *mut sockaddr_in6 as *mut sockaddr,
291 &mut addrlen as *mut socklen_t,
292 )
293 };
294 if ret < 0 {
295 let getsockname_err = io::Error::last_os_error();
296 Err(getsockname_err)
297 } else {
298 // If this doesn't match, it's not safe to get the port out of the sockaddr.
299 assert_eq!(addrlen as usize, size_of::<sockaddr_in>());
300
301 Ok(u16::from_be(sin6.sin6_port))
302 }
303 }
304 }
305 }
306 }
307
308 impl AsRawDescriptor for TcpSocket {
as_raw_descriptor(&self) -> RawDescriptor309 fn as_raw_descriptor(&self) -> RawDescriptor {
310 self.descriptor.as_raw_descriptor()
311 }
312 }
313
314 // Offset of sun_path in structure sockaddr_un.
sun_path_offset() -> usize315 fn sun_path_offset() -> usize {
316 // Prefer 0 to null() so that we do not need to subtract from the `sub_path` pointer.
317 #[allow(clippy::zero_ptr)]
318 let addr = 0 as *const libc::sockaddr_un;
319 // Safe because we only use the dereference to create a pointer to the desired field in
320 // calculating the offset.
321 unsafe { &(*addr).sun_path as *const _ as usize }
322 }
323
324 // Return `sockaddr_un` for a given `path`
sockaddr_un<P: AsRef<Path>>(path: P) -> io::Result<(libc::sockaddr_un, libc::socklen_t)>325 fn sockaddr_un<P: AsRef<Path>>(path: P) -> io::Result<(libc::sockaddr_un, libc::socklen_t)> {
326 let mut addr = libc::sockaddr_un {
327 sun_family: libc::AF_UNIX as libc::sa_family_t,
328 sun_path: [0; 108],
329 };
330
331 // Check if the input path is valid. Since
332 // * The pathname in sun_path should be null-terminated.
333 // * The length of the pathname, including the terminating null byte,
334 // should not exceed the size of sun_path.
335 //
336 // and our input is a `Path`, we only need to check
337 // * If the string size of `Path` should less than sizeof(sun_path)
338 // and make sure `sun_path` ends with '\0' by initialized the sun_path with zeros.
339 //
340 // Empty path name is valid since abstract socket address has sun_paht[0] = '\0'
341 let bytes = path.as_ref().as_os_str().as_bytes();
342 if bytes.len() >= addr.sun_path.len() {
343 return Err(io::Error::new(
344 io::ErrorKind::InvalidInput,
345 "Input path size should be less than the length of sun_path.",
346 ));
347 };
348
349 // Copy data from `path` to `addr.sun_path`
350 for (dst, src) in addr.sun_path.iter_mut().zip(bytes) {
351 *dst = *src as libc::c_char;
352 }
353
354 // The addrlen argument that describes the enclosing sockaddr_un structure
355 // should have a value of at least:
356 //
357 // offsetof(struct sockaddr_un, sun_path) + strlen(addr.sun_path) + 1
358 //
359 // or, more simply, addrlen can be specified as sizeof(struct sockaddr_un).
360 let len = sun_path_offset() + bytes.len() + 1;
361 Ok((addr, len as libc::socklen_t))
362 }
363
/// A Unix `SOCK_SEQPACKET` socket, for example one connected to a socket at a given `path`
/// (`UnixSeqpacket::connect`) or one half of a connected pair (`UnixSeqpacket::pair`).
///
/// The single field is the socket's descriptor.
#[derive(Debug, Serialize, Deserialize)]
pub struct UnixSeqpacket(SafeDescriptor);
367
// SAFETY: `UnixSeqpacket` only wraps a `SafeDescriptor`, i.e. a file descriptor. File
// descriptors are not tied to the thread that opened them, so it is safe to move a
// `UnixSeqpacket` to another thread.
unsafe impl Send for UnixSeqpacket {}
371
372 impl UnixSeqpacket {
373 /// Open a `SOCK_SEQPACKET` connection to socket named by `path`.
374 ///
375 /// # Arguments
376 /// * `path` - Path to `SOCK_SEQPACKET` socket
377 ///
378 /// # Returns
379 /// A `UnixSeqpacket` structure point to the socket
380 ///
381 /// # Errors
382 /// Return `io::Error` when error occurs.
connect<P: AsRef<Path>>(path: P) -> io::Result<Self>383 pub fn connect<P: AsRef<Path>>(path: P) -> io::Result<Self> {
384 let descriptor = socket(libc::AF_UNIX, libc::SOCK_SEQPACKET, 0)?;
385 let (addr, len) = sockaddr_un(path.as_ref())?;
386 // Safe connect since we handle the error and use the right length generated from
387 // `sockaddr_un`.
388 unsafe {
389 let ret = libc::connect(
390 descriptor.as_raw_descriptor(),
391 &addr as *const _ as *const _,
392 len,
393 );
394 if ret < 0 {
395 return Err(io::Error::last_os_error());
396 }
397 }
398 Ok(UnixSeqpacket(descriptor))
399 }
400
401 /// Creates a pair of connected `SOCK_SEQPACKET` sockets.
402 ///
403 /// Both returned file descriptors have the `CLOEXEC` flag set.
pair() -> io::Result<(UnixSeqpacket, UnixSeqpacket)>404 pub fn pair() -> io::Result<(UnixSeqpacket, UnixSeqpacket)> {
405 socketpair(libc::AF_UNIX, libc::SOCK_SEQPACKET | libc::SOCK_CLOEXEC, 0)
406 .map(|(s0, s1)| (UnixSeqpacket::from(s0), UnixSeqpacket::from(s1)))
407 }
408
409 /// Clone the underlying FD.
try_clone(&self) -> io::Result<Self>410 pub fn try_clone(&self) -> io::Result<Self> {
411 Ok(Self(self.0.try_clone()?))
412 }
413
414 /// Gets the number of bytes that can be read from this socket without blocking.
get_readable_bytes(&self) -> io::Result<usize>415 pub fn get_readable_bytes(&self) -> io::Result<usize> {
416 let mut byte_count = 0i32;
417 let ret = unsafe { libc::ioctl(self.as_raw_descriptor(), libc::FIONREAD, &mut byte_count) };
418 if ret < 0 {
419 Err(io::Error::last_os_error())
420 } else {
421 Ok(byte_count as usize)
422 }
423 }
424
425 /// Gets the number of bytes in the next packet. This blocks as if `recv` were called,
426 /// respecting the blocking and timeout settings of the underlying socket.
next_packet_size(&self) -> io::Result<usize>427 pub fn next_packet_size(&self) -> io::Result<usize> {
428 #[cfg(not(debug_assertions))]
429 let buf = null_mut();
430 // Work around for qemu's syscall translation which will reject null pointers in recvfrom.
431 // This only matters for running the unit tests for a non-native architecture. See the
432 // upstream thread for the qemu fix:
433 // https://lists.nongnu.org/archive/html/qemu-devel/2021-03/msg09027.html
434 #[cfg(debug_assertions)]
435 let buf = &mut 0 as *mut _ as *mut _;
436
437 // This form of recvfrom doesn't modify any data because all null pointers are used. We only
438 // use the return value and check for errors on an FD owned by this structure.
439 let ret = unsafe {
440 recvfrom(
441 self.as_raw_descriptor(),
442 buf,
443 0,
444 MSG_TRUNC | MSG_PEEK,
445 null_mut(),
446 null_mut(),
447 )
448 };
449 if ret < 0 {
450 Err(io::Error::last_os_error())
451 } else {
452 Ok(ret as usize)
453 }
454 }
455
456 /// Write data from a given buffer to the socket fd
457 ///
458 /// # Arguments
459 /// * `buf` - A reference to the data buffer.
460 ///
461 /// # Returns
462 /// * `usize` - The size of bytes written to the buffer.
463 ///
464 /// # Errors
465 /// Returns error when `libc::write` failed.
send(&self, buf: &[u8]) -> io::Result<usize>466 pub fn send(&self, buf: &[u8]) -> io::Result<usize> {
467 // Safe since we make sure the input `count` == `buf.len()` and handle the returned error.
468 unsafe {
469 let ret = libc::write(
470 self.as_raw_descriptor(),
471 buf.as_ptr() as *const _,
472 buf.len(),
473 );
474 if ret < 0 {
475 Err(io::Error::last_os_error())
476 } else {
477 Ok(ret as usize)
478 }
479 }
480 }
481
482 /// Read data from the socket fd to a given buffer
483 ///
484 /// # Arguments
485 /// * `buf` - A mut reference to the data buffer.
486 ///
487 /// # Returns
488 /// * `usize` - The size of bytes read to the buffer.
489 ///
490 /// # Errors
491 /// Returns error when `libc::read` failed.
recv(&self, buf: &mut [u8]) -> io::Result<usize>492 pub fn recv(&self, buf: &mut [u8]) -> io::Result<usize> {
493 // Safe since we make sure the input `count` == `buf.len()` and handle the returned error.
494 unsafe {
495 let ret = libc::read(
496 self.as_raw_descriptor(),
497 buf.as_mut_ptr() as *mut _,
498 buf.len(),
499 );
500 if ret < 0 {
501 Err(io::Error::last_os_error())
502 } else {
503 Ok(ret as usize)
504 }
505 }
506 }
507
508 /// Read data from the socket fd to a given `Vec`, resizing it to the received packet's size.
509 ///
510 /// # Arguments
511 /// * `buf` - A mut reference to a `Vec` to resize and read into.
512 ///
513 /// # Errors
514 /// Returns error when `libc::read` or `get_readable_bytes` failed.
recv_to_vec(&self, buf: &mut Vec<u8>) -> io::Result<()>515 pub fn recv_to_vec(&self, buf: &mut Vec<u8>) -> io::Result<()> {
516 let packet_size = self.next_packet_size()?;
517 buf.resize(packet_size, 0);
518 let read_bytes = self.recv(buf)?;
519 buf.resize(read_bytes, 0);
520 Ok(())
521 }
522
523 /// Read data from the socket fd to a new `Vec`.
524 ///
525 /// # Returns
526 /// * `vec` - A new `Vec` with the entire received packet.
527 ///
528 /// # Errors
529 /// Returns error when `libc::read` or `get_readable_bytes` failed.
recv_as_vec(&self) -> io::Result<Vec<u8>>530 pub fn recv_as_vec(&self) -> io::Result<Vec<u8>> {
531 let mut buf = Vec::new();
532 self.recv_to_vec(&mut buf)?;
533 Ok(buf)
534 }
535
536 /// Read data and fds from the socket fd to a new pair of `Vec`.
537 ///
538 /// # Returns
539 /// * `Vec<u8>` - A new `Vec` with the entire received packet's bytes.
540 /// * `Vec<RawFd>` - A new `Vec` with the entire received packet's fds.
541 ///
542 /// # Errors
543 /// Returns error when `recv_with_fds` or `get_readable_bytes` failed.
recv_as_vec_with_fds(&self) -> io::Result<(Vec<u8>, Vec<RawFd>)>544 pub fn recv_as_vec_with_fds(&self) -> io::Result<(Vec<u8>, Vec<RawFd>)> {
545 let packet_size = self.next_packet_size()?;
546 let mut buf = vec![0; packet_size];
547 let mut fd_buf = vec![-1; SCM_SOCKET_MAX_FD_COUNT];
548 let (read_bytes, read_fds) =
549 self.recv_with_fds(io::IoSliceMut::new(&mut buf), &mut fd_buf)?;
550 buf.resize(read_bytes, 0);
551 fd_buf.resize(read_fds, -1);
552 Ok((buf, fd_buf))
553 }
554
555 #[allow(clippy::useless_conversion)]
set_timeout(&self, timeout: Option<Duration>, kind: libc::c_int) -> io::Result<()>556 fn set_timeout(&self, timeout: Option<Duration>, kind: libc::c_int) -> io::Result<()> {
557 let timeval = match timeout {
558 Some(t) => {
559 if t.as_secs() == 0 && t.subsec_micros() == 0 {
560 return Err(io::Error::new(
561 io::ErrorKind::InvalidInput,
562 "zero timeout duration is invalid",
563 ));
564 }
565 // subsec_micros fits in i32 because it is defined to be less than one million.
566 let nsec = t.subsec_micros() as i32;
567 libc::timeval {
568 tv_sec: t.as_secs() as libc::time_t,
569 tv_usec: libc::suseconds_t::from(nsec),
570 }
571 }
572 None => libc::timeval {
573 tv_sec: 0,
574 tv_usec: 0,
575 },
576 };
577 // Safe because we own the fd, and the length of the pointer's data is the same as the
578 // passed in length parameter. The level argument is valid, the kind is assumed to be valid,
579 // and the return value is checked.
580 let ret = unsafe {
581 libc::setsockopt(
582 self.as_raw_descriptor(),
583 libc::SOL_SOCKET,
584 kind,
585 &timeval as *const libc::timeval as *const libc::c_void,
586 mem::size_of::<libc::timeval>() as libc::socklen_t,
587 )
588 };
589 if ret < 0 {
590 Err(io::Error::last_os_error())
591 } else {
592 Ok(())
593 }
594 }
595
596 /// Sets or removes the timeout for read/recv operations on this socket.
set_read_timeout(&self, timeout: Option<Duration>) -> io::Result<()>597 pub fn set_read_timeout(&self, timeout: Option<Duration>) -> io::Result<()> {
598 self.set_timeout(timeout, libc::SO_RCVTIMEO)
599 }
600
601 /// Sets or removes the timeout for write/send operations on this socket.
set_write_timeout(&self, timeout: Option<Duration>) -> io::Result<()>602 pub fn set_write_timeout(&self, timeout: Option<Duration>) -> io::Result<()> {
603 self.set_timeout(timeout, libc::SO_SNDTIMEO)
604 }
605
606 /// Sets the blocking mode for this socket.
set_nonblocking(&self, nonblocking: bool) -> io::Result<()>607 pub fn set_nonblocking(&self, nonblocking: bool) -> io::Result<()> {
608 let mut nonblocking = nonblocking as libc::c_int;
609 // Safe because the return value is checked, and this ioctl call sets the nonblocking mode
610 // and does not continue holding the file descriptor after the call.
611 let ret = unsafe { libc::ioctl(self.as_raw_descriptor(), libc::FIONBIO, &mut nonblocking) };
612 if ret < 0 {
613 Err(io::Error::last_os_error())
614 } else {
615 Ok(())
616 }
617 }
618 }
619
620 impl From<UnixSeqpacket> for SafeDescriptor {
from(s: UnixSeqpacket) -> Self621 fn from(s: UnixSeqpacket) -> Self {
622 s.0
623 }
624 }
625
626 impl From<SafeDescriptor> for UnixSeqpacket {
from(s: SafeDescriptor) -> Self627 fn from(s: SafeDescriptor) -> Self {
628 Self(s)
629 }
630 }
631
632 impl FromRawDescriptor for UnixSeqpacket {
from_raw_descriptor(descriptor: RawDescriptor) -> Self633 unsafe fn from_raw_descriptor(descriptor: RawDescriptor) -> Self {
634 Self(SafeDescriptor::from_raw_descriptor(descriptor))
635 }
636 }
637
638 impl AsRawDescriptor for UnixSeqpacket {
as_raw_descriptor(&self) -> RawDescriptor639 fn as_raw_descriptor(&self) -> RawDescriptor {
640 self.0.as_raw_descriptor()
641 }
642 }
643
644 impl io::Read for UnixSeqpacket {
read(&mut self, buf: &mut [u8]) -> io::Result<usize>645 fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
646 self.recv(buf)
647 }
648 }
649
650 impl io::Write for UnixSeqpacket {
write(&mut self, buf: &[u8]) -> io::Result<usize>651 fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
652 self.send(buf)
653 }
654
flush(&mut self) -> io::Result<()>655 fn flush(&mut self) -> io::Result<()> {
656 Ok(())
657 }
658 }
659
660 /// Like a `UnixListener` but for accepting `UnixSeqpacket` type sockets.
661 pub struct UnixSeqpacketListener {
662 descriptor: SafeDescriptor,
663 no_path: bool,
664 }
665
666 impl UnixSeqpacketListener {
667 /// Creates a new `UnixSeqpacketListener` bound to the given path.
bind<P: AsRef<Path>>(path: P) -> io::Result<Self>668 pub fn bind<P: AsRef<Path>>(path: P) -> io::Result<Self> {
669 if path.as_ref().starts_with("/proc/self/fd/") {
670 let fd = path
671 .as_ref()
672 .file_name()
673 .expect("Failed to get fd filename")
674 .to_str()
675 .expect("fd filename should be unicode")
676 .parse::<i32>()
677 .expect("fd should be an integer");
678 let mut result: c_int = 0;
679 let mut result_len = size_of::<c_int>() as libc::socklen_t;
680 let ret = unsafe {
681 libc::getsockopt(
682 fd,
683 libc::SOL_SOCKET,
684 libc::SO_ACCEPTCONN,
685 &mut result as *mut _ as *mut libc::c_void,
686 &mut result_len,
687 )
688 };
689 if ret < 0 {
690 return Err(io::Error::last_os_error());
691 }
692 if result != 1 {
693 return Err(io::Error::new(
694 io::ErrorKind::InvalidInput,
695 "specified descriptor is not a listening socket",
696 ));
697 }
698 // Safe because we validated the socket file descriptor.
699 let descriptor = unsafe { SafeDescriptor::from_raw_descriptor(fd) };
700 return Ok(UnixSeqpacketListener {
701 descriptor,
702 no_path: true,
703 });
704 }
705
706 let descriptor = socket(libc::AF_UNIX, libc::SOCK_SEQPACKET, 0)?;
707 let (addr, len) = sockaddr_un(path.as_ref())?;
708
709 // Safe connect since we handle the error and use the right length generated from
710 // `sockaddr_un`.
711 unsafe {
712 let ret = handle_eintr_errno!(libc::bind(
713 descriptor.as_raw_descriptor(),
714 &addr as *const _ as *const _,
715 len
716 ));
717 if ret < 0 {
718 return Err(io::Error::last_os_error());
719 }
720 let ret = handle_eintr_errno!(libc::listen(descriptor.as_raw_descriptor(), 128));
721 if ret < 0 {
722 return Err(io::Error::last_os_error());
723 }
724 }
725 Ok(UnixSeqpacketListener {
726 descriptor,
727 no_path: false,
728 })
729 }
730
731 /// Blocks for and accepts a new incoming connection and returns the socket associated with that
732 /// connection.
733 ///
734 /// The returned socket has the close-on-exec flag set.
accept(&self) -> io::Result<UnixSeqpacket>735 pub fn accept(&self) -> io::Result<UnixSeqpacket> {
736 // Safe because we own this fd and the kernel will not write to null pointers.
737 match unsafe {
738 libc::accept4(
739 self.as_raw_descriptor(),
740 null_mut(),
741 null_mut(),
742 libc::SOCK_CLOEXEC,
743 )
744 } {
745 -1 => Err(io::Error::last_os_error()),
746 fd => {
747 // Safe because we checked the return value of accept. Therefore, the return value
748 // must be a valid socket.
749 Ok(UnixSeqpacket::from(unsafe {
750 SafeDescriptor::from_raw_descriptor(fd)
751 }))
752 }
753 }
754 }
755
accept_with_timeout(&self, timeout: Duration) -> io::Result<UnixSeqpacket>756 pub fn accept_with_timeout(&self, timeout: Duration) -> io::Result<UnixSeqpacket> {
757 let start = Instant::now();
758
759 loop {
760 let mut fds = libc::pollfd {
761 fd: self.as_raw_descriptor(),
762 events: libc::POLLIN,
763 revents: 0,
764 };
765 let elapsed = Instant::now().saturating_duration_since(start);
766 let remaining = timeout.checked_sub(elapsed).unwrap_or(Duration::ZERO);
767 let cur_timeout_ms = i32::try_from(remaining.as_millis()).unwrap_or(i32::MAX);
768 // Safe because we give a valid pointer to a list (of 1) FD and we check
769 // the return value.
770 match unsafe { libc::poll(&mut fds, 1, cur_timeout_ms) }.cmp(&0) {
771 Ordering::Greater => return self.accept(),
772 Ordering::Equal => return Err(io::Error::from_raw_os_error(libc::ETIMEDOUT)),
773 Ordering::Less => {
774 if Error::last() != Error::new(libc::EINTR) {
775 return Err(io::Error::last_os_error());
776 }
777 }
778 }
779 }
780 }
781
782 /// Gets the path that this listener is bound to.
path(&self) -> io::Result<PathBuf>783 pub fn path(&self) -> io::Result<PathBuf> {
784 let mut addr = libc::sockaddr_un {
785 sun_family: libc::AF_UNIX as libc::sa_family_t,
786 sun_path: [0; 108],
787 };
788 if self.no_path {
789 return Err(io::Error::new(
790 io::ErrorKind::InvalidInput,
791 "socket has no path",
792 ));
793 }
794 let sun_path_offset = (&addr.sun_path as *const _ as usize
795 - &addr.sun_family as *const _ as usize)
796 as libc::socklen_t;
797 let mut len = mem::size_of::<libc::sockaddr_un>() as libc::socklen_t;
798 // Safe because the length given matches the length of the data of the given pointer, and we
799 // check the return value.
800 let ret = unsafe {
801 handle_eintr_errno!(libc::getsockname(
802 self.as_raw_descriptor(),
803 &mut addr as *mut libc::sockaddr_un as *mut libc::sockaddr,
804 &mut len
805 ))
806 };
807 if ret < 0 {
808 return Err(io::Error::last_os_error());
809 }
810 if addr.sun_family != libc::AF_UNIX as libc::sa_family_t
811 || addr.sun_path[0] == 0
812 || len < 1 + sun_path_offset
813 {
814 return Err(io::Error::new(
815 io::ErrorKind::InvalidInput,
816 "getsockname on socket returned invalid value",
817 ));
818 }
819
820 let path_os_str = OsString::from_vec(
821 addr.sun_path[..(len - sun_path_offset - 1) as usize]
822 .iter()
823 .map(|&c| c as _)
824 .collect(),
825 );
826 Ok(path_os_str.into())
827 }
828
829 /// Sets the blocking mode for this socket.
set_nonblocking(&self, nonblocking: bool) -> io::Result<()>830 pub fn set_nonblocking(&self, nonblocking: bool) -> io::Result<()> {
831 let mut nonblocking = nonblocking as libc::c_int;
832 // Safe because the return value is checked, and this ioctl call sets the nonblocking mode
833 // and does not continue holding the file descriptor after the call.
834 let ret = unsafe { libc::ioctl(self.as_raw_descriptor(), libc::FIONBIO, &mut nonblocking) };
835 if ret < 0 {
836 Err(io::Error::last_os_error())
837 } else {
838 Ok(())
839 }
840 }
841 }
842
843 impl AsRawDescriptor for UnixSeqpacketListener {
as_raw_descriptor(&self) -> RawDescriptor844 fn as_raw_descriptor(&self) -> RawDescriptor {
845 self.descriptor.as_raw_descriptor()
846 }
847 }
848
/// Wraps a `UnixSeqpacketListener` and, when dropped, tries to remove the socket file at the
/// listener's path.
///
/// Removal is best-effort: listeners whose path cannot be determined are skipped, and a failure
/// to remove the file is only logged as a warning.
pub struct UnlinkUnixSeqpacketListener(pub UnixSeqpacketListener);
851
852 impl AsRawDescriptor for UnlinkUnixSeqpacketListener {
as_raw_descriptor(&self) -> RawDescriptor853 fn as_raw_descriptor(&self) -> RawDescriptor {
854 self.0.as_raw_descriptor()
855 }
856 }
857
858 impl AsRef<UnixSeqpacketListener> for UnlinkUnixSeqpacketListener {
as_ref(&self) -> &UnixSeqpacketListener859 fn as_ref(&self) -> &UnixSeqpacketListener {
860 &self.0
861 }
862 }
863
864 impl Deref for UnlinkUnixSeqpacketListener {
865 type Target = UnixSeqpacketListener;
deref(&self) -> &Self::Target866 fn deref(&self) -> &Self::Target {
867 &self.0
868 }
869 }
870
871 impl Drop for UnlinkUnixSeqpacketListener {
drop(&mut self)872 fn drop(&mut self) {
873 if let Ok(path) = self.0.path() {
874 if let Err(e) = remove_file(path) {
875 warn!("failed to remove control socket file: {:?}", e);
876 }
877 }
878 }
879 }
880
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn sockaddr_un_zero_length_input() {
        sockaddr_un(Path::new("")).expect("sockaddr_un failed");
    }

    #[test]
    fn sockaddr_un_long_input_err() {
        // 108 bytes leaves no room for the terminating NUL, so it must be rejected.
        let too_long = "a".repeat(108);
        assert!(sockaddr_un(Path::new(&too_long)).is_err());
    }

    #[test]
    fn sockaddr_un_long_input_pass() {
        // 107 bytes plus the terminating NUL exactly fills `sun_path`.
        let longest = "a".repeat(107);
        sockaddr_un(Path::new(&longest)).expect("sockaddr_un failed");
    }

    #[test]
    fn sockaddr_un_len_check() {
        let name = "a".repeat(50);
        let (_, len) = sockaddr_un(Path::new(&name)).expect("sockaddr_un failed");
        assert_eq!(len, (sun_path_offset() + 50 + 1) as u32);
    }

    #[test]
    #[allow(clippy::unnecessary_cast)]
    // c_char is u8 on aarch64 and i8 on x86, so clippy's suggested fix of changing
    // `'a' as libc::c_char` below to `b'a'` won't work everywhere.
    #[allow(clippy::char_lit_as_u8)]
    fn sockaddr_un_pass() {
        let path_size = 50;
        let name = "a".repeat(path_size);
        let (addr, len) = sockaddr_un(Path::new(&name)).expect("sockaddr_un failed");
        assert_eq!(len, (sun_path_offset() + path_size + 1) as u32);
        assert_eq!(addr.sun_family, libc::AF_UNIX as libc::sa_family_t);

        // `sun_path` should hold `path_size` copies of 'a' followed by NUL bytes.
        let expected: Vec<libc::c_char> = (0..108)
            .map(|i| if i < path_size { 'a' as libc::c_char } else { 0 })
            .collect();
        for (actual, want) in addr.sun_path.iter().zip(expected.iter()) {
            assert_eq!(actual, want);
        }
    }
}
930