1 // Copyright 2018 The Chromium OS Authors. All rights reserved.
2 // Use of this source code is governed by a BSD-style license that can be
3 // found in the LICENSE file.
4
5 use std::{
6 cmp::Ordering,
7 convert::TryFrom,
8 ffi::OsString,
9 fs::remove_file,
10 io,
11 mem::{
12 size_of, {self},
13 },
14 net::{SocketAddr, SocketAddrV4, SocketAddrV6, TcpListener, TcpStream, ToSocketAddrs},
15 ops::Deref,
16 os::unix::{
17 ffi::{OsStrExt, OsStringExt},
18 io::{AsRawFd, FromRawFd, IntoRawFd, RawFd},
19 },
20 path::{Path, PathBuf},
21 ptr::null_mut,
22 time::{Duration, Instant},
23 };
24
25 use libc::{
26 c_int, in6_addr, in_addr, recvfrom, sa_family_t, sockaddr, sockaddr_in, sockaddr_in6,
27 socklen_t, AF_INET, AF_INET6, MSG_PEEK, MSG_TRUNC, SOCK_CLOEXEC, SOCK_STREAM,
28 };
29 use serde::{Deserialize, Serialize};
30
31 use super::{
32 sock_ctrl_msg::{ScmSocket, SCM_SOCKET_MAX_FD_COUNT},
33 Error, RawDescriptor,
34 };
35 use crate::descriptor::{AsRawDescriptor, FromRawDescriptor, IntoRawDescriptor};
36
/// Selects between the IPv4 and IPv6 address families.
#[derive(Debug, Copy, Clone, PartialEq, Eq)]
pub enum InetVersion {
    V4,
    V6,
}

impl InetVersion {
    /// Returns the IP version that corresponds to the variant of the socket address `s`.
    pub fn from_sockaddr(s: &SocketAddr) -> Self {
        // `SocketAddr` only has a V4 and a V6 variant, so "not V4" means V6.
        if s.is_ipv4() {
            InetVersion::V4
        } else {
            InetVersion::V6
        }
    }
}
52
53 impl From<InetVersion> for sa_family_t {
from(v: InetVersion) -> sa_family_t54 fn from(v: InetVersion) -> sa_family_t {
55 match v {
56 InetVersion::V4 => AF_INET as sa_family_t,
57 InetVersion::V6 => AF_INET6 as sa_family_t,
58 }
59 }
60 }
61
sockaddrv4_to_lib_c(s: &SocketAddrV4) -> sockaddr_in62 fn sockaddrv4_to_lib_c(s: &SocketAddrV4) -> sockaddr_in {
63 sockaddr_in {
64 sin_family: AF_INET as sa_family_t,
65 sin_port: s.port().to_be(),
66 sin_addr: in_addr {
67 s_addr: u32::from_ne_bytes(s.ip().octets()),
68 },
69 sin_zero: [0; 8],
70 }
71 }
72
sockaddrv6_to_lib_c(s: &SocketAddrV6) -> sockaddr_in673 fn sockaddrv6_to_lib_c(s: &SocketAddrV6) -> sockaddr_in6 {
74 sockaddr_in6 {
75 sin6_family: AF_INET6 as sa_family_t,
76 sin6_port: s.port().to_be(),
77 sin6_flowinfo: 0,
78 sin6_addr: in6_addr {
79 s6_addr: s.ip().octets(),
80 },
81 sin6_scope_id: 0,
82 }
83 }
84
85 /// A TCP socket.
86 ///
87 /// Do not use this class unless you need to change socket options or query the
88 /// state of the socket prior to calling listen or connect. Instead use either TcpStream or
89 /// TcpListener.
90 #[derive(Debug)]
91 pub struct TcpSocket {
92 inet_version: InetVersion,
93 fd: RawFd,
94 }
95
96 impl TcpSocket {
new(inet_version: InetVersion) -> io::Result<Self>97 pub fn new(inet_version: InetVersion) -> io::Result<Self> {
98 let fd = unsafe {
99 libc::socket(
100 Into::<sa_family_t>::into(inet_version) as c_int,
101 SOCK_STREAM | SOCK_CLOEXEC,
102 0,
103 )
104 };
105 if fd < 0 {
106 Err(io::Error::last_os_error())
107 } else {
108 Ok(TcpSocket { inet_version, fd })
109 }
110 }
111
bind<A: ToSocketAddrs>(&mut self, addr: A) -> io::Result<()>112 pub fn bind<A: ToSocketAddrs>(&mut self, addr: A) -> io::Result<()> {
113 let sockaddr = addr
114 .to_socket_addrs()
115 .map_err(|_| io::Error::from_raw_os_error(libc::EINVAL))?
116 .next()
117 .unwrap();
118
119 let ret = match sockaddr {
120 SocketAddr::V4(a) => {
121 let sin = sockaddrv4_to_lib_c(&a);
122 // Safe because this doesn't modify any memory and we check the return value.
123 unsafe {
124 libc::bind(
125 self.fd,
126 &sin as *const sockaddr_in as *const sockaddr,
127 size_of::<sockaddr_in>() as socklen_t,
128 )
129 }
130 }
131 SocketAddr::V6(a) => {
132 let sin6 = sockaddrv6_to_lib_c(&a);
133 // Safe because this doesn't modify any memory and we check the return value.
134 unsafe {
135 libc::bind(
136 self.fd,
137 &sin6 as *const sockaddr_in6 as *const sockaddr,
138 size_of::<sockaddr_in6>() as socklen_t,
139 )
140 }
141 }
142 };
143 if ret < 0 {
144 let bind_err = io::Error::last_os_error();
145 Err(bind_err)
146 } else {
147 Ok(())
148 }
149 }
150
connect<A: ToSocketAddrs>(self, addr: A) -> io::Result<TcpStream>151 pub fn connect<A: ToSocketAddrs>(self, addr: A) -> io::Result<TcpStream> {
152 let sockaddr = addr
153 .to_socket_addrs()
154 .map_err(|_| io::Error::from_raw_os_error(libc::EINVAL))?
155 .next()
156 .unwrap();
157
158 let ret = match sockaddr {
159 SocketAddr::V4(a) => {
160 let sin = sockaddrv4_to_lib_c(&a);
161 // Safe because this doesn't modify any memory and we check the return value.
162 unsafe {
163 libc::connect(
164 self.fd,
165 &sin as *const sockaddr_in as *const sockaddr,
166 size_of::<sockaddr_in>() as socklen_t,
167 )
168 }
169 }
170 SocketAddr::V6(a) => {
171 let sin6 = sockaddrv6_to_lib_c(&a);
172 // Safe because this doesn't modify any memory and we check the return value.
173 unsafe {
174 libc::connect(
175 self.fd,
176 &sin6 as *const sockaddr_in6 as *const sockaddr,
177 size_of::<sockaddr_in>() as socklen_t,
178 )
179 }
180 }
181 };
182
183 if ret < 0 {
184 let connect_err = io::Error::last_os_error();
185 Err(connect_err)
186 } else {
187 // Safe because the ownership of the raw fd is released from self and taken over by the
188 // new TcpStream.
189 Ok(unsafe { TcpStream::from_raw_fd(self.into_raw_fd()) })
190 }
191 }
192
listen(self) -> io::Result<TcpListener>193 pub fn listen(self) -> io::Result<TcpListener> {
194 // Safe because this doesn't modify any memory and we check the return value.
195 let ret = unsafe { libc::listen(self.fd, 1) };
196 if ret < 0 {
197 let listen_err = io::Error::last_os_error();
198 Err(listen_err)
199 } else {
200 // Safe because the ownership of the raw fd is released from self and taken over by the
201 // new TcpListener.
202 Ok(unsafe { TcpListener::from_raw_fd(self.into_raw_fd()) })
203 }
204 }
205
206 /// Returns the port that this socket is bound to. This can only succeed after bind is called.
local_port(&self) -> io::Result<u16>207 pub fn local_port(&self) -> io::Result<u16> {
208 match self.inet_version {
209 InetVersion::V4 => {
210 let mut sin = sockaddr_in {
211 sin_family: 0,
212 sin_port: 0,
213 sin_addr: in_addr { s_addr: 0 },
214 sin_zero: [0; 8],
215 };
216
217 // Safe because we give a valid pointer for addrlen and check the length.
218 let mut addrlen = size_of::<sockaddr_in>() as socklen_t;
219 let ret = unsafe {
220 // Get the socket address that was actually bound.
221 libc::getsockname(
222 self.fd,
223 &mut sin as *mut sockaddr_in as *mut sockaddr,
224 &mut addrlen as *mut socklen_t,
225 )
226 };
227 if ret < 0 {
228 let getsockname_err = io::Error::last_os_error();
229 Err(getsockname_err)
230 } else {
231 // If this doesn't match, it's not safe to get the port out of the sockaddr.
232 assert_eq!(addrlen as usize, size_of::<sockaddr_in>());
233
234 Ok(u16::from_be(sin.sin_port))
235 }
236 }
237 InetVersion::V6 => {
238 let mut sin6 = sockaddr_in6 {
239 sin6_family: 0,
240 sin6_port: 0,
241 sin6_flowinfo: 0,
242 sin6_addr: in6_addr { s6_addr: [0; 16] },
243 sin6_scope_id: 0,
244 };
245
246 // Safe because we give a valid pointer for addrlen and check the length.
247 let mut addrlen = size_of::<sockaddr_in6>() as socklen_t;
248 let ret = unsafe {
249 // Get the socket address that was actually bound.
250 libc::getsockname(
251 self.fd,
252 &mut sin6 as *mut sockaddr_in6 as *mut sockaddr,
253 &mut addrlen as *mut socklen_t,
254 )
255 };
256 if ret < 0 {
257 let getsockname_err = io::Error::last_os_error();
258 Err(getsockname_err)
259 } else {
260 // If this doesn't match, it's not safe to get the port out of the sockaddr.
261 assert_eq!(addrlen as usize, size_of::<sockaddr_in>());
262
263 Ok(u16::from_be(sin6.sin6_port))
264 }
265 }
266 }
267 }
268 }
269
270 impl IntoRawFd for TcpSocket {
into_raw_fd(self) -> RawFd271 fn into_raw_fd(self) -> RawFd {
272 let fd = self.fd;
273 mem::forget(self);
274 fd
275 }
276 }
277
278 impl AsRawFd for TcpSocket {
as_raw_fd(&self) -> RawFd279 fn as_raw_fd(&self) -> RawFd {
280 self.fd
281 }
282 }
283
284 impl Drop for TcpSocket {
drop(&mut self)285 fn drop(&mut self) {
286 // Safe because this doesn't modify any memory and we are the only
287 // owner of the file descriptor.
288 unsafe { libc::close(self.fd) };
289 }
290 }
291
292 // Offset of sun_path in structure sockaddr_un.
sun_path_offset() -> usize293 fn sun_path_offset() -> usize {
294 // Prefer 0 to null() so that we do not need to subtract from the `sub_path` pointer.
295 #[allow(clippy::zero_ptr)]
296 let addr = 0 as *const libc::sockaddr_un;
297 // Safe because we only use the dereference to create a pointer to the desired field in
298 // calculating the offset.
299 unsafe { &(*addr).sun_path as *const _ as usize }
300 }
301
302 // Return `sockaddr_un` for a given `path`
sockaddr_un<P: AsRef<Path>>(path: P) -> io::Result<(libc::sockaddr_un, libc::socklen_t)>303 fn sockaddr_un<P: AsRef<Path>>(path: P) -> io::Result<(libc::sockaddr_un, libc::socklen_t)> {
304 let mut addr = libc::sockaddr_un {
305 sun_family: libc::AF_UNIX as libc::sa_family_t,
306 sun_path: [0; 108],
307 };
308
309 // Check if the input path is valid. Since
310 // * The pathname in sun_path should be null-terminated.
311 // * The length of the pathname, including the terminating null byte,
312 // should not exceed the size of sun_path.
313 //
314 // and our input is a `Path`, we only need to check
315 // * If the string size of `Path` should less than sizeof(sun_path)
316 // and make sure `sun_path` ends with '\0' by initialized the sun_path with zeros.
317 //
318 // Empty path name is valid since abstract socket address has sun_paht[0] = '\0'
319 let bytes = path.as_ref().as_os_str().as_bytes();
320 if bytes.len() >= addr.sun_path.len() {
321 return Err(io::Error::new(
322 io::ErrorKind::InvalidInput,
323 "Input path size should be less than the length of sun_path.",
324 ));
325 };
326
327 // Copy data from `path` to `addr.sun_path`
328 for (dst, src) in addr.sun_path.iter_mut().zip(bytes) {
329 *dst = *src as libc::c_char;
330 }
331
332 // The addrlen argument that describes the enclosing sockaddr_un structure
333 // should have a value of at least:
334 //
335 // offsetof(struct sockaddr_un, sun_path) + strlen(addr.sun_path) + 1
336 //
337 // or, more simply, addrlen can be specified as sizeof(struct sockaddr_un).
338 let len = sun_path_offset() + bytes.len() + 1;
339 Ok((addr, len as libc::socklen_t))
340 }
341
342 /// A Unix `SOCK_SEQPACKET` socket point to given `path`
343 #[derive(Debug, Serialize, Deserialize)]
344 pub struct UnixSeqpacket {
345 #[serde(with = "super::with_raw_descriptor")]
346 fd: RawFd,
347 }
348
349 impl UnixSeqpacket {
350 /// Open a `SOCK_SEQPACKET` connection to socket named by `path`.
351 ///
352 /// # Arguments
353 /// * `path` - Path to `SOCK_SEQPACKET` socket
354 ///
355 /// # Returns
356 /// A `UnixSeqpacket` structure point to the socket
357 ///
358 /// # Errors
359 /// Return `io::Error` when error occurs.
connect<P: AsRef<Path>>(path: P) -> io::Result<Self>360 pub fn connect<P: AsRef<Path>>(path: P) -> io::Result<Self> {
361 // Safe socket initialization since we handle the returned error.
362 let fd = unsafe {
363 match libc::socket(libc::AF_UNIX, libc::SOCK_SEQPACKET, 0) {
364 -1 => return Err(io::Error::last_os_error()),
365 fd => fd,
366 }
367 };
368
369 let (addr, len) = sockaddr_un(path.as_ref())?;
370 // Safe connect since we handle the error and use the right length generated from
371 // `sockaddr_un`.
372 unsafe {
373 let ret = libc::connect(fd, &addr as *const _ as *const _, len);
374 if ret < 0 {
375 return Err(io::Error::last_os_error());
376 }
377 }
378 Ok(UnixSeqpacket { fd })
379 }
380
381 /// Creates a pair of connected `SOCK_SEQPACKET` sockets.
382 ///
383 /// Both returned file descriptors have the `CLOEXEC` flag set.s
pair() -> io::Result<(UnixSeqpacket, UnixSeqpacket)>384 pub fn pair() -> io::Result<(UnixSeqpacket, UnixSeqpacket)> {
385 let mut fds = [0, 0];
386 unsafe {
387 // Safe because we give enough space to store all the fds and we check the return value.
388 let ret = libc::socketpair(
389 libc::AF_UNIX,
390 libc::SOCK_SEQPACKET | libc::SOCK_CLOEXEC,
391 0,
392 &mut fds[0],
393 );
394 if ret == 0 {
395 Ok((
396 UnixSeqpacket::from_raw_fd(fds[0]),
397 UnixSeqpacket::from_raw_fd(fds[1]),
398 ))
399 } else {
400 Err(io::Error::last_os_error())
401 }
402 }
403 }
404
405 /// Clone the underlying FD.
try_clone(&self) -> io::Result<Self>406 pub fn try_clone(&self) -> io::Result<Self> {
407 // Safe because this doesn't modify any memory and we check the return value.
408 let fd = unsafe { libc::fcntl(self.fd, libc::F_DUPFD_CLOEXEC, 0) };
409 if fd < 0 {
410 Err(io::Error::last_os_error())
411 } else {
412 Ok(Self { fd })
413 }
414 }
415
416 /// Gets the number of bytes that can be read from this socket without blocking.
get_readable_bytes(&self) -> io::Result<usize>417 pub fn get_readable_bytes(&self) -> io::Result<usize> {
418 let mut byte_count = 0i32;
419 let ret = unsafe { libc::ioctl(self.fd, libc::FIONREAD, &mut byte_count) };
420 if ret < 0 {
421 Err(io::Error::last_os_error())
422 } else {
423 Ok(byte_count as usize)
424 }
425 }
426
427 /// Gets the number of bytes in the next packet. This blocks as if `recv` were called,
428 /// respecting the blocking and timeout settings of the underlying socket.
next_packet_size(&self) -> io::Result<usize>429 pub fn next_packet_size(&self) -> io::Result<usize> {
430 #[cfg(not(debug_assertions))]
431 let buf = null_mut();
432 // Work around for qemu's syscall translation which will reject null pointers in recvfrom.
433 // This only matters for running the unit tests for a non-native architecture. See the
434 // upstream thread for the qemu fix:
435 // https://lists.nongnu.org/archive/html/qemu-devel/2021-03/msg09027.html
436 #[cfg(debug_assertions)]
437 let buf = &mut 0 as *mut _ as *mut _;
438
439 // This form of recvfrom doesn't modify any data because all null pointers are used. We only
440 // use the return value and check for errors on an FD owned by this structure.
441 let ret = unsafe {
442 recvfrom(
443 self.fd,
444 buf,
445 0,
446 MSG_TRUNC | MSG_PEEK,
447 null_mut(),
448 null_mut(),
449 )
450 };
451 if ret < 0 {
452 Err(io::Error::last_os_error())
453 } else {
454 Ok(ret as usize)
455 }
456 }
457
458 /// Write data from a given buffer to the socket fd
459 ///
460 /// # Arguments
461 /// * `buf` - A reference to the data buffer.
462 ///
463 /// # Returns
464 /// * `usize` - The size of bytes written to the buffer.
465 ///
466 /// # Errors
467 /// Returns error when `libc::write` failed.
send(&self, buf: &[u8]) -> io::Result<usize>468 pub fn send(&self, buf: &[u8]) -> io::Result<usize> {
469 // Safe since we make sure the input `count` == `buf.len()` and handle the returned error.
470 unsafe {
471 let ret = libc::write(self.fd, buf.as_ptr() as *const _, buf.len());
472 if ret < 0 {
473 Err(io::Error::last_os_error())
474 } else {
475 Ok(ret as usize)
476 }
477 }
478 }
479
480 /// Read data from the socket fd to a given buffer
481 ///
482 /// # Arguments
483 /// * `buf` - A mut reference to the data buffer.
484 ///
485 /// # Returns
486 /// * `usize` - The size of bytes read to the buffer.
487 ///
488 /// # Errors
489 /// Returns error when `libc::read` failed.
recv(&self, buf: &mut [u8]) -> io::Result<usize>490 pub fn recv(&self, buf: &mut [u8]) -> io::Result<usize> {
491 // Safe since we make sure the input `count` == `buf.len()` and handle the returned error.
492 unsafe {
493 let ret = libc::read(self.fd, buf.as_mut_ptr() as *mut _, buf.len());
494 if ret < 0 {
495 Err(io::Error::last_os_error())
496 } else {
497 Ok(ret as usize)
498 }
499 }
500 }
501
502 /// Read data from the socket fd to a given `Vec`, resizing it to the received packet's size.
503 ///
504 /// # Arguments
505 /// * `buf` - A mut reference to a `Vec` to resize and read into.
506 ///
507 /// # Errors
508 /// Returns error when `libc::read` or `get_readable_bytes` failed.
recv_to_vec(&self, buf: &mut Vec<u8>) -> io::Result<()>509 pub fn recv_to_vec(&self, buf: &mut Vec<u8>) -> io::Result<()> {
510 let packet_size = self.next_packet_size()?;
511 buf.resize(packet_size, 0);
512 let read_bytes = self.recv(buf)?;
513 buf.resize(read_bytes, 0);
514 Ok(())
515 }
516
517 /// Read data from the socket fd to a new `Vec`.
518 ///
519 /// # Returns
520 /// * `vec` - A new `Vec` with the entire received packet.
521 ///
522 /// # Errors
523 /// Returns error when `libc::read` or `get_readable_bytes` failed.
recv_as_vec(&self) -> io::Result<Vec<u8>>524 pub fn recv_as_vec(&self) -> io::Result<Vec<u8>> {
525 let mut buf = Vec::new();
526 self.recv_to_vec(&mut buf)?;
527 Ok(buf)
528 }
529
530 /// Read data and fds from the socket fd to a new pair of `Vec`.
531 ///
532 /// # Returns
533 /// * `Vec<u8>` - A new `Vec` with the entire received packet's bytes.
534 /// * `Vec<RawFd>` - A new `Vec` with the entire received packet's fds.
535 ///
536 /// # Errors
537 /// Returns error when `recv_with_fds` or `get_readable_bytes` failed.
recv_as_vec_with_fds(&self) -> io::Result<(Vec<u8>, Vec<RawFd>)>538 pub fn recv_as_vec_with_fds(&self) -> io::Result<(Vec<u8>, Vec<RawFd>)> {
539 let packet_size = self.next_packet_size()?;
540 let mut buf = vec![0; packet_size];
541 let mut fd_buf = vec![-1; SCM_SOCKET_MAX_FD_COUNT];
542 let (read_bytes, read_fds) =
543 self.recv_with_fds(io::IoSliceMut::new(&mut buf), &mut fd_buf)?;
544 buf.resize(read_bytes, 0);
545 fd_buf.resize(read_fds, -1);
546 Ok((buf, fd_buf))
547 }
548
set_timeout(&self, timeout: Option<Duration>, kind: libc::c_int) -> io::Result<()>549 fn set_timeout(&self, timeout: Option<Duration>, kind: libc::c_int) -> io::Result<()> {
550 let timeval = match timeout {
551 Some(t) => {
552 if t.as_secs() == 0 && t.subsec_micros() == 0 {
553 return Err(io::Error::new(
554 io::ErrorKind::InvalidInput,
555 "zero timeout duration is invalid",
556 ));
557 }
558 // subsec_micros fits in i32 because it is defined to be less than one million.
559 let nsec = t.subsec_micros() as i32;
560 libc::timeval {
561 tv_sec: t.as_secs() as libc::time_t,
562 tv_usec: libc::suseconds_t::from(nsec),
563 }
564 }
565 None => libc::timeval {
566 tv_sec: 0,
567 tv_usec: 0,
568 },
569 };
570 // Safe because we own the fd, and the length of the pointer's data is the same as the
571 // passed in length parameter. The level argument is valid, the kind is assumed to be valid,
572 // and the return value is checked.
573 let ret = unsafe {
574 libc::setsockopt(
575 self.fd,
576 libc::SOL_SOCKET,
577 kind,
578 &timeval as *const libc::timeval as *const libc::c_void,
579 mem::size_of::<libc::timeval>() as libc::socklen_t,
580 )
581 };
582 if ret < 0 {
583 Err(io::Error::last_os_error())
584 } else {
585 Ok(())
586 }
587 }
588
589 /// Sets or removes the timeout for read/recv operations on this socket.
set_read_timeout(&self, timeout: Option<Duration>) -> io::Result<()>590 pub fn set_read_timeout(&self, timeout: Option<Duration>) -> io::Result<()> {
591 self.set_timeout(timeout, libc::SO_RCVTIMEO)
592 }
593
594 /// Sets or removes the timeout for write/send operations on this socket.
set_write_timeout(&self, timeout: Option<Duration>) -> io::Result<()>595 pub fn set_write_timeout(&self, timeout: Option<Duration>) -> io::Result<()> {
596 self.set_timeout(timeout, libc::SO_SNDTIMEO)
597 }
598 }
599
600 impl Drop for UnixSeqpacket {
drop(&mut self)601 fn drop(&mut self) {
602 // Safe if the UnixSeqpacket is created from Self::connect.
603 unsafe {
604 libc::close(self.fd);
605 }
606 }
607 }
608
609 impl FromRawFd for UnixSeqpacket {
610 // Unsafe in drop function
from_raw_fd(fd: RawFd) -> Self611 unsafe fn from_raw_fd(fd: RawFd) -> Self {
612 Self { fd }
613 }
614 }
615
616 impl FromRawDescriptor for UnixSeqpacket {
from_raw_descriptor(descriptor: RawDescriptor) -> Self617 unsafe fn from_raw_descriptor(descriptor: RawDescriptor) -> Self {
618 Self { fd: descriptor }
619 }
620 }
621
622 impl IntoRawDescriptor for UnixSeqpacket {
into_raw_descriptor(self) -> RawDescriptor623 fn into_raw_descriptor(self) -> RawDescriptor {
624 let fd = self.fd;
625 mem::forget(self);
626 fd
627 }
628 }
629
630 impl AsRawFd for UnixSeqpacket {
as_raw_fd(&self) -> RawFd631 fn as_raw_fd(&self) -> RawFd {
632 self.fd
633 }
634 }
635
636 impl AsRawFd for &UnixSeqpacket {
as_raw_fd(&self) -> RawFd637 fn as_raw_fd(&self) -> RawFd {
638 self.fd
639 }
640 }
641
642 impl AsRawDescriptor for UnixSeqpacket {
as_raw_descriptor(&self) -> RawDescriptor643 fn as_raw_descriptor(&self) -> RawDescriptor {
644 self.fd
645 }
646 }
647
648 impl io::Read for UnixSeqpacket {
read(&mut self, buf: &mut [u8]) -> io::Result<usize>649 fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
650 self.recv(buf)
651 }
652 }
653
654 impl io::Write for UnixSeqpacket {
write(&mut self, buf: &[u8]) -> io::Result<usize>655 fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
656 self.send(buf)
657 }
658
flush(&mut self) -> io::Result<()>659 fn flush(&mut self) -> io::Result<()> {
660 Ok(())
661 }
662 }
663
664 impl IntoRawFd for UnixSeqpacket {
into_raw_fd(self) -> RawFd665 fn into_raw_fd(self) -> RawFd {
666 let fd = self.fd;
667 mem::forget(self);
668 fd
669 }
670 }
671
/// The `SOCK_SEQPACKET` counterpart of a `UnixListener`: accepts `UnixSeqpacket` connections.
///
/// The wrapped file descriptor is closed when the value is dropped.
pub struct UnixSeqpacketListener {
    // Raw file descriptor of the listening socket.
    fd: RawFd,
}
676
677 impl UnixSeqpacketListener {
678 /// Creates a new `UnixSeqpacketListener` bound to the given path.
bind<P: AsRef<Path>>(path: P) -> io::Result<Self>679 pub fn bind<P: AsRef<Path>>(path: P) -> io::Result<Self> {
680 // Safe socket initialization since we handle the returned error.
681 let fd = unsafe {
682 match libc::socket(libc::AF_UNIX, libc::SOCK_SEQPACKET, 0) {
683 -1 => return Err(io::Error::last_os_error()),
684 fd => fd,
685 }
686 };
687
688 let (addr, len) = sockaddr_un(path.as_ref())?;
689 // Safe connect since we handle the error and use the right length generated from
690 // `sockaddr_un`.
691 unsafe {
692 let ret = handle_eintr_errno!(libc::bind(fd, &addr as *const _ as *const _, len));
693 if ret < 0 {
694 return Err(io::Error::last_os_error());
695 }
696 let ret = handle_eintr_errno!(libc::listen(fd, 128));
697 if ret < 0 {
698 return Err(io::Error::last_os_error());
699 }
700 }
701 Ok(UnixSeqpacketListener { fd })
702 }
703
704 /// Blocks for and accepts a new incoming connection and returns the socket associated with that
705 /// connection.
706 ///
707 /// The returned socket has the close-on-exec flag set.
accept(&self) -> io::Result<UnixSeqpacket>708 pub fn accept(&self) -> io::Result<UnixSeqpacket> {
709 // Safe because we own this fd and the kernel will not write to null pointers.
710 let ret = unsafe { libc::accept4(self.fd, null_mut(), null_mut(), libc::SOCK_CLOEXEC) };
711 if ret < 0 {
712 return Err(io::Error::last_os_error());
713 }
714 // Safe because we checked the return value of accept. Therefore, the return value must be a
715 // valid socket.
716 Ok(unsafe { UnixSeqpacket::from_raw_fd(ret) })
717 }
718
accept_with_timeout(&self, timeout: Duration) -> io::Result<UnixSeqpacket>719 pub fn accept_with_timeout(&self, timeout: Duration) -> io::Result<UnixSeqpacket> {
720 let start = Instant::now();
721
722 loop {
723 let mut fds = libc::pollfd {
724 fd: self.fd,
725 events: libc::POLLIN,
726 revents: 0,
727 };
728 let elapsed = Instant::now().saturating_duration_since(start);
729 let remaining = timeout.checked_sub(elapsed).unwrap_or(Duration::ZERO);
730 let cur_timeout_ms = i32::try_from(remaining.as_millis()).unwrap_or(i32::MAX);
731 // Safe because we give a valid pointer to a list (of 1) FD and we check
732 // the return value.
733 match unsafe { libc::poll(&mut fds, 1, cur_timeout_ms) }.cmp(&0) {
734 Ordering::Greater => return self.accept(),
735 Ordering::Equal => return Err(io::Error::from_raw_os_error(libc::ETIMEDOUT)),
736 Ordering::Less => {
737 if Error::last() != Error::new(libc::EINTR) {
738 return Err(io::Error::last_os_error());
739 }
740 }
741 }
742 }
743 }
744
745 /// Gets the path that this listener is bound to.
path(&self) -> io::Result<PathBuf>746 pub fn path(&self) -> io::Result<PathBuf> {
747 let mut addr = libc::sockaddr_un {
748 sun_family: libc::AF_UNIX as libc::sa_family_t,
749 sun_path: [0; 108],
750 };
751 let sun_path_offset = (&addr.sun_path as *const _ as usize
752 - &addr.sun_family as *const _ as usize)
753 as libc::socklen_t;
754 let mut len = mem::size_of::<libc::sockaddr_un>() as libc::socklen_t;
755 // Safe because the length given matches the length of the data of the given pointer, and we
756 // check the return value.
757 let ret = unsafe {
758 handle_eintr_errno!(libc::getsockname(
759 self.fd,
760 &mut addr as *mut libc::sockaddr_un as *mut libc::sockaddr,
761 &mut len
762 ))
763 };
764 if ret < 0 {
765 return Err(io::Error::last_os_error());
766 }
767 if addr.sun_family != libc::AF_UNIX as libc::sa_family_t
768 || addr.sun_path[0] == 0
769 || len < 1 + sun_path_offset
770 {
771 return Err(io::Error::new(
772 io::ErrorKind::InvalidInput,
773 "getsockname on socket returned invalid value",
774 ));
775 }
776
777 let path_os_str = OsString::from_vec(
778 addr.sun_path[..(len - sun_path_offset - 1) as usize]
779 .iter()
780 .map(|&c| c as _)
781 .collect(),
782 );
783 Ok(path_os_str.into())
784 }
785 }
786
787 impl Drop for UnixSeqpacketListener {
drop(&mut self)788 fn drop(&mut self) {
789 // Safe if the UnixSeqpacketListener is created from Self::listen.
790 unsafe {
791 libc::close(self.fd);
792 }
793 }
794 }
795
796 impl FromRawFd for UnixSeqpacketListener {
797 // Unsafe in drop function
from_raw_fd(fd: RawFd) -> Self798 unsafe fn from_raw_fd(fd: RawFd) -> Self {
799 Self { fd }
800 }
801 }
802
803 impl AsRawFd for UnixSeqpacketListener {
as_raw_fd(&self) -> RawFd804 fn as_raw_fd(&self) -> RawFd {
805 self.fd
806 }
807 }
808
809 /// Used to attempt to clean up a `UnixSeqpacketListener` after it is dropped.
810 pub struct UnlinkUnixSeqpacketListener(pub UnixSeqpacketListener);
811 impl AsRef<UnixSeqpacketListener> for UnlinkUnixSeqpacketListener {
as_ref(&self) -> &UnixSeqpacketListener812 fn as_ref(&self) -> &UnixSeqpacketListener {
813 &self.0
814 }
815 }
816
817 impl AsRawFd for UnlinkUnixSeqpacketListener {
as_raw_fd(&self) -> RawFd818 fn as_raw_fd(&self) -> RawFd {
819 self.0.as_raw_fd()
820 }
821 }
822
823 impl Deref for UnlinkUnixSeqpacketListener {
824 type Target = UnixSeqpacketListener;
deref(&self) -> &Self::Target825 fn deref(&self) -> &Self::Target {
826 &self.0
827 }
828 }
829
830 impl Drop for UnlinkUnixSeqpacketListener {
drop(&mut self)831 fn drop(&mut self) {
832 if let Ok(path) = self.0.path() {
833 if let Err(e) = remove_file(path) {
834 warn!("failed to remove control socket file: {:?}", e);
835 }
836 }
837 }
838 }
839
#[cfg(test)]
mod tests {
    use super::*;
    use std::{env, io::ErrorKind, path::PathBuf};

    /// Directory in which the tests create their socket files.
    fn tmpdir() -> PathBuf {
        env::temp_dir()
    }

    /// Returns the path of a socket file called `name` inside the temporary directory.
    fn socket_path_for(name: &str) -> PathBuf {
        let mut path = tmpdir();
        path.push(name);
        path
    }

    /// Binds a listener at `path`, wrapped so that the socket file is unlinked on drop.
    fn bind_unlinking(path: &Path) -> UnlinkUnixSeqpacketListener {
        UnlinkUnixSeqpacketListener(
            UnixSeqpacketListener::bind(path).expect("failed to create UnixSeqpacketListener"),
        )
    }

    /// Sends one packet in each direction between `s1` and `s2` and checks that each side
    /// receives what the other side sent.
    fn exchange_packets(s1: &UnixSeqpacket, s2: &UnixSeqpacket) {
        let data1 = &[0, 1, 2, 3, 4];
        let data2 = &[10, 11, 12, 13, 14];
        s2.send(data2).expect("failed to send data2");
        s1.send(data1).expect("failed to send data1");
        let recv_data = &mut [0; 5];
        s2.recv(recv_data).expect("failed to recv data");
        assert_eq!(data1, recv_data);
        s1.recv(recv_data).expect("failed to recv data");
        assert_eq!(data2, recv_data);
    }

    #[test]
    fn sockaddr_un_zero_length_input() {
        let _res = sockaddr_un(Path::new("")).expect("sockaddr_un failed");
    }

    #[test]
    fn sockaddr_un_long_input_err() {
        // 108 bytes leave no room for the terminating null byte.
        let too_long = "a".repeat(108);
        assert!(sockaddr_un(Path::new(&too_long)).is_err());
    }

    #[test]
    fn sockaddr_un_long_input_pass() {
        // 107 bytes is the longest path that still fits.
        let longest = "a".repeat(107);
        let _res = sockaddr_un(Path::new(&longest)).expect("sockaddr_un failed");
    }

    #[test]
    fn sockaddr_un_len_check() {
        let path = "a".repeat(50);
        let (_addr, len) = sockaddr_un(Path::new(&path)).expect("sockaddr_un failed");
        assert_eq!(len, (sun_path_offset() + 50 + 1) as u32);
    }

    #[test]
    #[allow(clippy::unnecessary_cast)]
    fn sockaddr_un_pass() {
        let path_size = 50;
        let (addr, len) =
            sockaddr_un(Path::new(&"a".repeat(path_size))).expect("sockaddr_un failed");
        assert_eq!(len, (sun_path_offset() + path_size + 1) as u32);
        assert_eq!(addr.sun_family, libc::AF_UNIX as libc::sa_family_t);

        // Check `sun_path` in returned `sockaddr_un`: the path bytes followed by zero padding.
        for (idx, actual) in addr.sun_path.iter().enumerate() {
            let expected = if idx < path_size {
                'a' as libc::c_char
            } else {
                0 as libc::c_char
            };
            assert_eq!(*actual, expected);
        }
    }

    #[test]
    fn unix_seqpacket_path_not_exists() {
        assert!(UnixSeqpacket::connect("/path/not/exists").is_err());
    }

    #[test]
    fn unix_seqpacket_listener_path() {
        let socket_path = socket_path_for("unix_seqpacket_listener_path");
        let listener = bind_unlinking(&socket_path);
        let listener_path = listener.path().expect("failed to get socket listener path");
        assert_eq!(socket_path, listener_path);
    }

    #[test]
    fn unix_seqpacket_path_exists_pass() {
        let socket_path = socket_path_for("path_to_socket");
        let _listener = bind_unlinking(&socket_path);
        let _res =
            UnixSeqpacket::connect(socket_path.as_path()).expect("UnixSeqpacket::connect failed");
    }

    #[test]
    fn unix_seqpacket_path_listener_accept_with_timeout() {
        let socket_path = socket_path_for("path_listerner_accept_with_timeout");
        let listener = bind_unlinking(&socket_path);

        for d in [Duration::from_millis(10), Duration::ZERO] {
            // Nothing has connected yet, so accepting must fail.
            let _ = listener.accept_with_timeout(d).expect_err(&format!(
                "UnixSeqpacket::accept_with_timeout {:?} connected",
                d
            ));

            let s1 = UnixSeqpacket::connect(socket_path.as_path())
                .unwrap_or_else(|_| panic!("UnixSeqpacket::connect {:?} failed", d));

            let s2 = listener
                .accept_with_timeout(d)
                .unwrap_or_else(|_| panic!("UnixSeqpacket::accept {:?} failed", d));

            exchange_packets(&s1, &s2);
        }
    }

    #[test]
    fn unix_seqpacket_path_listener_accept() {
        let socket_path = socket_path_for("path_listerner_accept");
        let listener = bind_unlinking(&socket_path);
        let s1 =
            UnixSeqpacket::connect(socket_path.as_path()).expect("UnixSeqpacket::connect failed");

        let s2 = listener.accept().expect("UnixSeqpacket::accept failed");

        exchange_packets(&s1, &s2);
    }

    #[test]
    fn unix_seqpacket_zero_timeout() {
        let (s1, _s2) = UnixSeqpacket::pair().expect("failed to create socket pair");
        // Timeouts less than a microsecond are too small and round to zero.
        let too_small = Duration::from_nanos(10);
        s1.set_read_timeout(Some(too_small))
            .expect_err("successfully set zero timeout");
    }

    #[test]
    fn unix_seqpacket_read_timeout() {
        let (s1, _s2) = UnixSeqpacket::pair().expect("failed to create socket pair");
        s1.set_read_timeout(Some(Duration::from_millis(1)))
            .expect("failed to set read timeout for socket");
        // Nothing was sent; the result of the read is not inspected.
        let _ = s1.recv(&mut [0]);
    }

    #[test]
    fn unix_seqpacket_write_timeout() {
        let (s1, _s2) = UnixSeqpacket::pair().expect("failed to create socket pair");
        s1.set_write_timeout(Some(Duration::from_millis(1)))
            .expect("failed to set write timeout for socket");
    }

    #[test]
    fn unix_seqpacket_send_recv() {
        let (s1, s2) = UnixSeqpacket::pair().expect("failed to create socket pair");
        exchange_packets(&s1, &s2);
    }

    #[test]
    fn unix_seqpacket_send_fragments() {
        let (s1, s2) = UnixSeqpacket::pair().expect("failed to create socket pair");
        let data1 = &[0, 1, 2, 3, 4];
        let data2 = &[10, 11, 12, 13, 14, 15, 16];
        s1.send(data1).expect("failed to send data1");
        s1.send(data2).expect("failed to send data2");

        // Each recv must return exactly one packet even though the buffer could hold both.
        let recv_data = &mut [0; 32];
        let size = s2.recv(recv_data).expect("failed to recv data");
        assert_eq!(size, data1.len());
        assert_eq!(data1, &recv_data[0..size]);

        let size = s2.recv(recv_data).expect("failed to recv data");
        assert_eq!(size, data2.len());
        assert_eq!(data2, &recv_data[0..size]);
    }

    #[test]
    fn unix_seqpacket_get_readable_bytes() {
        let (s1, s2) = UnixSeqpacket::pair().expect("failed to create socket pair");
        assert_eq!(s1.get_readable_bytes().unwrap(), 0);
        assert_eq!(s2.get_readable_bytes().unwrap(), 0);
        let data1 = &[0, 1, 2, 3, 4];
        s1.send(data1).expect("failed to send data");

        // Only the receiving side sees pending bytes.
        assert_eq!(s1.get_readable_bytes().unwrap(), 0);
        assert_eq!(s2.get_readable_bytes().unwrap(), data1.len());

        let recv_data = &mut [0; 5];
        s2.recv(recv_data).expect("failed to recv data");
        assert_eq!(s1.get_readable_bytes().unwrap(), 0);
        assert_eq!(s2.get_readable_bytes().unwrap(), 0);
    }

    #[test]
    fn unix_seqpacket_next_packet_size() {
        let (s1, s2) = UnixSeqpacket::pair().expect("failed to create socket pair");
        let data1 = &[0, 1, 2, 3, 4];
        s1.send(data1).expect("failed to send data");

        assert_eq!(s2.next_packet_size().unwrap(), 5);

        // With nothing queued for s1, the read timeout makes the peek give up.
        s1.set_read_timeout(Some(Duration::from_micros(1)))
            .expect("failed to set read timeout");
        let timed_out = s1.next_packet_size().unwrap_err();
        assert_eq!(timed_out.kind(), ErrorKind::WouldBlock);

        // Closing the peer while it still holds an unread packet resets the connection.
        drop(s2);
        let reset = s1.next_packet_size().unwrap_err();
        assert_eq!(reset.kind(), ErrorKind::ConnectionReset);
    }

    #[test]
    fn unix_seqpacket_recv_to_vec() {
        let (s1, s2) = UnixSeqpacket::pair().expect("failed to create socket pair");
        let data1 = &[0, 1, 2, 3, 4];
        s1.send(data1).expect("failed to send data");

        let recv_data = &mut vec![];
        s2.recv_to_vec(recv_data).expect("failed to recv data");
        assert_eq!(recv_data, &mut vec![0, 1, 2, 3, 4]);
    }

    #[test]
    fn unix_seqpacket_recv_as_vec() {
        let (s1, s2) = UnixSeqpacket::pair().expect("failed to create socket pair");
        let data1 = &[0, 1, 2, 3, 4];
        s1.send(data1).expect("failed to send data");

        let recv_data = s2.recv_as_vec().expect("failed to recv data");
        assert_eq!(recv_data, vec![0, 1, 2, 3, 4]);
    }
}
1092