1 // Copyright 2018 The Chromium OS Authors. All rights reserved.
2 // Use of this source code is governed by a BSD-style license that can be
3 // found in the LICENSE file.
4
5 use std::ffi::OsString;
6 use std::fs::remove_file;
7 use std::io;
8 use std::mem::{self, size_of};
9 use std::net::{SocketAddr, SocketAddrV4, SocketAddrV6, TcpListener, TcpStream, ToSocketAddrs};
10 use std::ops::Deref;
11 use std::os::unix::{
12 ffi::{OsStrExt, OsStringExt},
13 io::{AsRawFd, FromRawFd, IntoRawFd, RawFd},
14 };
15 use std::path::Path;
16 use std::path::PathBuf;
17 use std::ptr::null_mut;
18 use std::time::Duration;
19
20 use libc::{
21 c_int, in6_addr, in_addr, recvfrom, sa_family_t, sockaddr, sockaddr_in, sockaddr_in6,
22 socklen_t, AF_INET, AF_INET6, MSG_PEEK, MSG_TRUNC, SOCK_CLOEXEC, SOCK_STREAM,
23 };
24 use serde::{Deserialize, Serialize};
25
26 use crate::sock_ctrl_msg::{ScmSocket, SCM_SOCKET_MAX_FD_COUNT};
27 use crate::{AsRawDescriptor, RawDescriptor};
28
/// Assist in handling both IP version 4 and IP version 6.
///
/// Selects the address family used when a [`TcpSocket`] is created and the `sockaddr` layout
/// used when querying its bound address.
#[derive(Debug, Copy, Clone, PartialEq, Eq)]
pub enum InetVersion {
    /// IPv4; converts to `AF_INET`.
    V4,
    /// IPv6; converts to `AF_INET6`.
    V6,
}
35
36 impl InetVersion {
from_sockaddr(s: &SocketAddr) -> Self37 pub fn from_sockaddr(s: &SocketAddr) -> Self {
38 match s {
39 SocketAddr::V4(_) => InetVersion::V4,
40 SocketAddr::V6(_) => InetVersion::V6,
41 }
42 }
43 }
44
45 impl From<InetVersion> for sa_family_t {
from(v: InetVersion) -> sa_family_t46 fn from(v: InetVersion) -> sa_family_t {
47 match v {
48 InetVersion::V4 => AF_INET as sa_family_t,
49 InetVersion::V6 => AF_INET6 as sa_family_t,
50 }
51 }
52 }
53
sockaddrv4_to_lib_c(s: &SocketAddrV4) -> sockaddr_in54 fn sockaddrv4_to_lib_c(s: &SocketAddrV4) -> sockaddr_in {
55 sockaddr_in {
56 sin_family: AF_INET as sa_family_t,
57 sin_port: s.port().to_be(),
58 sin_addr: in_addr {
59 s_addr: u32::from_ne_bytes(s.ip().octets()),
60 },
61 sin_zero: [0; 8],
62 }
63 }
64
sockaddrv6_to_lib_c(s: &SocketAddrV6) -> sockaddr_in665 fn sockaddrv6_to_lib_c(s: &SocketAddrV6) -> sockaddr_in6 {
66 sockaddr_in6 {
67 sin6_family: AF_INET6 as sa_family_t,
68 sin6_port: s.port().to_be(),
69 sin6_flowinfo: 0,
70 sin6_addr: in6_addr {
71 s6_addr: s.ip().octets(),
72 },
73 sin6_scope_id: 0,
74 }
75 }
76
/// A TCP socket.
///
/// Do not use this type unless you need to change socket options or query the
/// state of the socket prior to calling listen or connect. Instead use either TcpStream or
/// TcpListener.
#[derive(Debug)]
pub struct TcpSocket {
    // Address family the socket was created with; `local_port` uses it to choose between
    // `sockaddr_in` and `sockaddr_in6` when querying the bound address.
    inet_version: InetVersion,
    // Owned socket descriptor: closed on drop unless released with `into_raw_fd`.
    fd: RawFd,
}
87
88 impl TcpSocket {
new(inet_version: InetVersion) -> io::Result<Self>89 pub fn new(inet_version: InetVersion) -> io::Result<Self> {
90 let fd = unsafe {
91 libc::socket(
92 Into::<sa_family_t>::into(inet_version) as c_int,
93 SOCK_STREAM | SOCK_CLOEXEC,
94 0,
95 )
96 };
97 if fd < 0 {
98 Err(io::Error::last_os_error())
99 } else {
100 Ok(TcpSocket { inet_version, fd })
101 }
102 }
103
bind<A: ToSocketAddrs>(&mut self, addr: A) -> io::Result<()>104 pub fn bind<A: ToSocketAddrs>(&mut self, addr: A) -> io::Result<()> {
105 let sockaddr = addr
106 .to_socket_addrs()
107 .map_err(|_| io::Error::from_raw_os_error(libc::EINVAL))?
108 .next()
109 .unwrap();
110
111 let ret = match sockaddr {
112 SocketAddr::V4(a) => {
113 let sin = sockaddrv4_to_lib_c(&a);
114 // Safe because this doesn't modify any memory and we check the return value.
115 unsafe {
116 libc::bind(
117 self.fd,
118 &sin as *const sockaddr_in as *const sockaddr,
119 size_of::<sockaddr_in>() as socklen_t,
120 )
121 }
122 }
123 SocketAddr::V6(a) => {
124 let sin6 = sockaddrv6_to_lib_c(&a);
125 // Safe because this doesn't modify any memory and we check the return value.
126 unsafe {
127 libc::bind(
128 self.fd,
129 &sin6 as *const sockaddr_in6 as *const sockaddr,
130 size_of::<sockaddr_in6>() as socklen_t,
131 )
132 }
133 }
134 };
135 if ret < 0 {
136 let bind_err = io::Error::last_os_error();
137 Err(bind_err)
138 } else {
139 Ok(())
140 }
141 }
142
connect<A: ToSocketAddrs>(self, addr: A) -> io::Result<TcpStream>143 pub fn connect<A: ToSocketAddrs>(self, addr: A) -> io::Result<TcpStream> {
144 let sockaddr = addr
145 .to_socket_addrs()
146 .map_err(|_| io::Error::from_raw_os_error(libc::EINVAL))?
147 .next()
148 .unwrap();
149
150 let ret = match sockaddr {
151 SocketAddr::V4(a) => {
152 let sin = sockaddrv4_to_lib_c(&a);
153 // Safe because this doesn't modify any memory and we check the return value.
154 unsafe {
155 libc::connect(
156 self.fd,
157 &sin as *const sockaddr_in as *const sockaddr,
158 size_of::<sockaddr_in>() as socklen_t,
159 )
160 }
161 }
162 SocketAddr::V6(a) => {
163 let sin6 = sockaddrv6_to_lib_c(&a);
164 // Safe because this doesn't modify any memory and we check the return value.
165 unsafe {
166 libc::connect(
167 self.fd,
168 &sin6 as *const sockaddr_in6 as *const sockaddr,
169 size_of::<sockaddr_in>() as socklen_t,
170 )
171 }
172 }
173 };
174
175 if ret < 0 {
176 let connect_err = io::Error::last_os_error();
177 Err(connect_err)
178 } else {
179 // Safe because the ownership of the raw fd is released from self and taken over by the
180 // new TcpStream.
181 Ok(unsafe { TcpStream::from_raw_fd(self.into_raw_fd()) })
182 }
183 }
184
listen(self) -> io::Result<TcpListener>185 pub fn listen(self) -> io::Result<TcpListener> {
186 // Safe because this doesn't modify any memory and we check the return value.
187 let ret = unsafe { libc::listen(self.fd, 1) };
188 if ret < 0 {
189 let listen_err = io::Error::last_os_error();
190 Err(listen_err)
191 } else {
192 // Safe because the ownership of the raw fd is released from self and taken over by the
193 // new TcpListener.
194 Ok(unsafe { TcpListener::from_raw_fd(self.into_raw_fd()) })
195 }
196 }
197
198 /// Returns the port that this socket is bound to. This can only succeed after bind is called.
local_port(&self) -> io::Result<u16>199 pub fn local_port(&self) -> io::Result<u16> {
200 match self.inet_version {
201 InetVersion::V4 => {
202 let mut sin = sockaddr_in {
203 sin_family: 0,
204 sin_port: 0,
205 sin_addr: in_addr { s_addr: 0 },
206 sin_zero: [0; 8],
207 };
208
209 // Safe because we give a valid pointer for addrlen and check the length.
210 let mut addrlen = size_of::<sockaddr_in>() as socklen_t;
211 let ret = unsafe {
212 // Get the socket address that was actually bound.
213 libc::getsockname(
214 self.fd,
215 &mut sin as *mut sockaddr_in as *mut sockaddr,
216 &mut addrlen as *mut socklen_t,
217 )
218 };
219 if ret < 0 {
220 let getsockname_err = io::Error::last_os_error();
221 Err(getsockname_err)
222 } else {
223 // If this doesn't match, it's not safe to get the port out of the sockaddr.
224 assert_eq!(addrlen as usize, size_of::<sockaddr_in>());
225
226 Ok(u16::from_be(sin.sin_port))
227 }
228 }
229 InetVersion::V6 => {
230 let mut sin6 = sockaddr_in6 {
231 sin6_family: 0,
232 sin6_port: 0,
233 sin6_flowinfo: 0,
234 sin6_addr: in6_addr { s6_addr: [0; 16] },
235 sin6_scope_id: 0,
236 };
237
238 // Safe because we give a valid pointer for addrlen and check the length.
239 let mut addrlen = size_of::<sockaddr_in6>() as socklen_t;
240 let ret = unsafe {
241 // Get the socket address that was actually bound.
242 libc::getsockname(
243 self.fd,
244 &mut sin6 as *mut sockaddr_in6 as *mut sockaddr,
245 &mut addrlen as *mut socklen_t,
246 )
247 };
248 if ret < 0 {
249 let getsockname_err = io::Error::last_os_error();
250 Err(getsockname_err)
251 } else {
252 // If this doesn't match, it's not safe to get the port out of the sockaddr.
253 assert_eq!(addrlen as usize, size_of::<sockaddr_in>());
254
255 Ok(u16::from_be(sin6.sin6_port))
256 }
257 }
258 }
259 }
260 }
261
262 impl IntoRawFd for TcpSocket {
into_raw_fd(self) -> RawFd263 fn into_raw_fd(self) -> RawFd {
264 let fd = self.fd;
265 mem::forget(self);
266 fd
267 }
268 }
269
270 impl AsRawFd for TcpSocket {
as_raw_fd(&self) -> RawFd271 fn as_raw_fd(&self) -> RawFd {
272 self.fd
273 }
274 }
275
276 impl Drop for TcpSocket {
drop(&mut self)277 fn drop(&mut self) {
278 // Safe because this doesn't modify any memory and we are the only
279 // owner of the file descriptor.
280 unsafe { libc::close(self.fd) };
281 }
282 }
283
284 // Offset of sun_path in structure sockaddr_un.
sun_path_offset() -> usize285 fn sun_path_offset() -> usize {
286 // Prefer 0 to null() so that we do not need to subtract from the `sub_path` pointer.
287 #[allow(clippy::zero_ptr)]
288 let addr = 0 as *const libc::sockaddr_un;
289 // Safe because we only use the dereference to create a pointer to the desired field in
290 // calculating the offset.
291 unsafe { &(*addr).sun_path as *const _ as usize }
292 }
293
294 // Return `sockaddr_un` for a given `path`
sockaddr_un<P: AsRef<Path>>(path: P) -> io::Result<(libc::sockaddr_un, libc::socklen_t)>295 fn sockaddr_un<P: AsRef<Path>>(path: P) -> io::Result<(libc::sockaddr_un, libc::socklen_t)> {
296 let mut addr = libc::sockaddr_un {
297 sun_family: libc::AF_UNIX as libc::sa_family_t,
298 sun_path: [0; 108],
299 };
300
301 // Check if the input path is valid. Since
302 // * The pathname in sun_path should be null-terminated.
303 // * The length of the pathname, including the terminating null byte,
304 // should not exceed the size of sun_path.
305 //
306 // and our input is a `Path`, we only need to check
307 // * If the string size of `Path` should less than sizeof(sun_path)
308 // and make sure `sun_path` ends with '\0' by initialized the sun_path with zeros.
309 //
310 // Empty path name is valid since abstract socket address has sun_paht[0] = '\0'
311 let bytes = path.as_ref().as_os_str().as_bytes();
312 if bytes.len() >= addr.sun_path.len() {
313 return Err(io::Error::new(
314 io::ErrorKind::InvalidInput,
315 "Input path size should be less than the length of sun_path.",
316 ));
317 };
318
319 // Copy data from `path` to `addr.sun_path`
320 for (dst, src) in addr.sun_path.iter_mut().zip(bytes) {
321 *dst = *src as libc::c_char;
322 }
323
324 // The addrlen argument that describes the enclosing sockaddr_un structure
325 // should have a value of at least:
326 //
327 // offsetof(struct sockaddr_un, sun_path) + strlen(addr.sun_path) + 1
328 //
329 // or, more simply, addrlen can be specified as sizeof(struct sockaddr_un).
330 let len = sun_path_offset() + bytes.len() + 1;
331 Ok((addr, len as libc::socklen_t))
332 }
333
/// A Unix `SOCK_SEQPACKET` socket. It is obtained by connecting to a socket at a path
/// (`connect`), as one end of a connected pair (`pair`), or from
/// `UnixSeqpacketListener::accept`.
///
/// The descriptor is owned and is closed when the value is dropped.
#[derive(Serialize, Deserialize)]
pub struct UnixSeqpacket {
    // (De)serialized through the `crate::with_raw_descriptor` helper module.
    #[serde(with = "crate::with_raw_descriptor")]
    fd: RawFd,
}
340
341 impl UnixSeqpacket {
342 /// Open a `SOCK_SEQPACKET` connection to socket named by `path`.
343 ///
344 /// # Arguments
345 /// * `path` - Path to `SOCK_SEQPACKET` socket
346 ///
347 /// # Returns
348 /// A `UnixSeqpacket` structure point to the socket
349 ///
350 /// # Errors
351 /// Return `io::Error` when error occurs.
connect<P: AsRef<Path>>(path: P) -> io::Result<Self>352 pub fn connect<P: AsRef<Path>>(path: P) -> io::Result<Self> {
353 // Safe socket initialization since we handle the returned error.
354 let fd = unsafe {
355 match libc::socket(libc::AF_UNIX, libc::SOCK_SEQPACKET, 0) {
356 -1 => return Err(io::Error::last_os_error()),
357 fd => fd,
358 }
359 };
360
361 let (addr, len) = sockaddr_un(path.as_ref())?;
362 // Safe connect since we handle the error and use the right length generated from
363 // `sockaddr_un`.
364 unsafe {
365 let ret = libc::connect(fd, &addr as *const _ as *const _, len);
366 if ret < 0 {
367 return Err(io::Error::last_os_error());
368 }
369 }
370 Ok(UnixSeqpacket { fd })
371 }
372
373 /// Creates a pair of connected `SOCK_SEQPACKET` sockets.
374 ///
375 /// Both returned file descriptors have the `CLOEXEC` flag set.s
pair() -> io::Result<(UnixSeqpacket, UnixSeqpacket)>376 pub fn pair() -> io::Result<(UnixSeqpacket, UnixSeqpacket)> {
377 let mut fds = [0, 0];
378 unsafe {
379 // Safe because we give enough space to store all the fds and we check the return value.
380 let ret = libc::socketpair(
381 libc::AF_UNIX,
382 libc::SOCK_SEQPACKET | libc::SOCK_CLOEXEC,
383 0,
384 &mut fds[0],
385 );
386 if ret == 0 {
387 Ok((
388 UnixSeqpacket::from_raw_fd(fds[0]),
389 UnixSeqpacket::from_raw_fd(fds[1]),
390 ))
391 } else {
392 Err(io::Error::last_os_error())
393 }
394 }
395 }
396
397 /// Clone the underlying FD.
try_clone(&self) -> io::Result<Self>398 pub fn try_clone(&self) -> io::Result<Self> {
399 // Safe because this doesn't modify any memory and we check the return value.
400 let fd = unsafe { libc::fcntl(self.fd, libc::F_DUPFD_CLOEXEC, 0) };
401 if fd < 0 {
402 Err(io::Error::last_os_error())
403 } else {
404 Ok(Self { fd })
405 }
406 }
407
408 /// Gets the number of bytes that can be read from this socket without blocking.
get_readable_bytes(&self) -> io::Result<usize>409 pub fn get_readable_bytes(&self) -> io::Result<usize> {
410 let mut byte_count = 0i32;
411 let ret = unsafe { libc::ioctl(self.fd, libc::FIONREAD, &mut byte_count) };
412 if ret < 0 {
413 Err(io::Error::last_os_error())
414 } else {
415 Ok(byte_count as usize)
416 }
417 }
418
419 /// Gets the number of bytes in the next packet. This blocks as if `recv` were called,
420 /// respecting the blocking and timeout settings of the underlying socket.
next_packet_size(&self) -> io::Result<usize>421 pub fn next_packet_size(&self) -> io::Result<usize> {
422 #[cfg(not(debug_assertions))]
423 let buf = null_mut();
424 // Work around for qemu's syscall translation which will reject null pointers in recvfrom.
425 // This only matters for running the unit tests for a non-native architecture. See the
426 // upstream thread for the qemu fix:
427 // https://lists.nongnu.org/archive/html/qemu-devel/2021-03/msg09027.html
428 #[cfg(debug_assertions)]
429 let buf = &mut 0 as *mut _ as *mut _;
430
431 // This form of recvfrom doesn't modify any data because all null pointers are used. We only
432 // use the return value and check for errors on an FD owned by this structure.
433 let ret = unsafe {
434 recvfrom(
435 self.fd,
436 buf,
437 0,
438 MSG_TRUNC | MSG_PEEK,
439 null_mut(),
440 null_mut(),
441 )
442 };
443 if ret < 0 {
444 Err(io::Error::last_os_error())
445 } else {
446 Ok(ret as usize)
447 }
448 }
449
450 /// Write data from a given buffer to the socket fd
451 ///
452 /// # Arguments
453 /// * `buf` - A reference to the data buffer.
454 ///
455 /// # Returns
456 /// * `usize` - The size of bytes written to the buffer.
457 ///
458 /// # Errors
459 /// Returns error when `libc::write` failed.
send(&self, buf: &[u8]) -> io::Result<usize>460 pub fn send(&self, buf: &[u8]) -> io::Result<usize> {
461 // Safe since we make sure the input `count` == `buf.len()` and handle the returned error.
462 unsafe {
463 let ret = libc::write(self.fd, buf.as_ptr() as *const _, buf.len());
464 if ret < 0 {
465 Err(io::Error::last_os_error())
466 } else {
467 Ok(ret as usize)
468 }
469 }
470 }
471
472 /// Read data from the socket fd to a given buffer
473 ///
474 /// # Arguments
475 /// * `buf` - A mut reference to the data buffer.
476 ///
477 /// # Returns
478 /// * `usize` - The size of bytes read to the buffer.
479 ///
480 /// # Errors
481 /// Returns error when `libc::read` failed.
recv(&self, buf: &mut [u8]) -> io::Result<usize>482 pub fn recv(&self, buf: &mut [u8]) -> io::Result<usize> {
483 // Safe since we make sure the input `count` == `buf.len()` and handle the returned error.
484 unsafe {
485 let ret = libc::read(self.fd, buf.as_mut_ptr() as *mut _, buf.len());
486 if ret < 0 {
487 Err(io::Error::last_os_error())
488 } else {
489 Ok(ret as usize)
490 }
491 }
492 }
493
494 /// Read data from the socket fd to a given `Vec`, resizing it to the received packet's size.
495 ///
496 /// # Arguments
497 /// * `buf` - A mut reference to a `Vec` to resize and read into.
498 ///
499 /// # Errors
500 /// Returns error when `libc::read` or `get_readable_bytes` failed.
recv_to_vec(&self, buf: &mut Vec<u8>) -> io::Result<()>501 pub fn recv_to_vec(&self, buf: &mut Vec<u8>) -> io::Result<()> {
502 let packet_size = self.next_packet_size()?;
503 buf.resize(packet_size, 0);
504 let read_bytes = self.recv(buf)?;
505 buf.resize(read_bytes, 0);
506 Ok(())
507 }
508
509 /// Read data from the socket fd to a new `Vec`.
510 ///
511 /// # Returns
512 /// * `vec` - A new `Vec` with the entire received packet.
513 ///
514 /// # Errors
515 /// Returns error when `libc::read` or `get_readable_bytes` failed.
recv_as_vec(&self) -> io::Result<Vec<u8>>516 pub fn recv_as_vec(&self) -> io::Result<Vec<u8>> {
517 let mut buf = Vec::new();
518 self.recv_to_vec(&mut buf)?;
519 Ok(buf)
520 }
521
522 /// Read data and fds from the socket fd to a new pair of `Vec`.
523 ///
524 /// # Returns
525 /// * `Vec<u8>` - A new `Vec` with the entire received packet's bytes.
526 /// * `Vec<RawFd>` - A new `Vec` with the entire received packet's fds.
527 ///
528 /// # Errors
529 /// Returns error when `recv_with_fds` or `get_readable_bytes` failed.
recv_as_vec_with_fds(&self) -> io::Result<(Vec<u8>, Vec<RawFd>)>530 pub fn recv_as_vec_with_fds(&self) -> io::Result<(Vec<u8>, Vec<RawFd>)> {
531 let packet_size = self.next_packet_size()?;
532 let mut buf = vec![0; packet_size];
533 let mut fd_buf = vec![-1; SCM_SOCKET_MAX_FD_COUNT];
534 let (read_bytes, read_fds) = self.recv_with_fds(&mut buf, &mut fd_buf)?;
535 buf.resize(read_bytes, 0);
536 fd_buf.resize(read_fds, -1);
537 Ok((buf, fd_buf))
538 }
539
set_timeout(&self, timeout: Option<Duration>, kind: libc::c_int) -> io::Result<()>540 fn set_timeout(&self, timeout: Option<Duration>, kind: libc::c_int) -> io::Result<()> {
541 let timeval = match timeout {
542 Some(t) => {
543 if t.as_secs() == 0 && t.subsec_micros() == 0 {
544 return Err(io::Error::new(
545 io::ErrorKind::InvalidInput,
546 "zero timeout duration is invalid",
547 ));
548 }
549 // subsec_micros fits in i32 because it is defined to be less than one million.
550 let nsec = t.subsec_micros() as i32;
551 libc::timeval {
552 tv_sec: t.as_secs() as libc::time_t,
553 tv_usec: libc::suseconds_t::from(nsec),
554 }
555 }
556 None => libc::timeval {
557 tv_sec: 0,
558 tv_usec: 0,
559 },
560 };
561 // Safe because we own the fd, and the length of the pointer's data is the same as the
562 // passed in length parameter. The level argument is valid, the kind is assumed to be valid,
563 // and the return value is checked.
564 let ret = unsafe {
565 libc::setsockopt(
566 self.fd,
567 libc::SOL_SOCKET,
568 kind,
569 &timeval as *const libc::timeval as *const libc::c_void,
570 mem::size_of::<libc::timeval>() as libc::socklen_t,
571 )
572 };
573 if ret < 0 {
574 Err(io::Error::last_os_error())
575 } else {
576 Ok(())
577 }
578 }
579
580 /// Sets or removes the timeout for read/recv operations on this socket.
set_read_timeout(&self, timeout: Option<Duration>) -> io::Result<()>581 pub fn set_read_timeout(&self, timeout: Option<Duration>) -> io::Result<()> {
582 self.set_timeout(timeout, libc::SO_RCVTIMEO)
583 }
584
585 /// Sets or removes the timeout for write/send operations on this socket.
set_write_timeout(&self, timeout: Option<Duration>) -> io::Result<()>586 pub fn set_write_timeout(&self, timeout: Option<Duration>) -> io::Result<()> {
587 self.set_timeout(timeout, libc::SO_SNDTIMEO)
588 }
589 }
590
591 impl Drop for UnixSeqpacket {
drop(&mut self)592 fn drop(&mut self) {
593 // Safe if the UnixSeqpacket is created from Self::connect.
594 unsafe {
595 libc::close(self.fd);
596 }
597 }
598 }
599
600 impl FromRawFd for UnixSeqpacket {
601 // Unsafe in drop function
from_raw_fd(fd: RawFd) -> Self602 unsafe fn from_raw_fd(fd: RawFd) -> Self {
603 Self { fd }
604 }
605 }
606
607 impl AsRawFd for UnixSeqpacket {
as_raw_fd(&self) -> RawFd608 fn as_raw_fd(&self) -> RawFd {
609 self.fd
610 }
611 }
612
613 impl AsRawFd for &UnixSeqpacket {
as_raw_fd(&self) -> RawFd614 fn as_raw_fd(&self) -> RawFd {
615 self.fd
616 }
617 }
618
619 impl AsRawDescriptor for UnixSeqpacket {
as_raw_descriptor(&self) -> RawDescriptor620 fn as_raw_descriptor(&self) -> RawDescriptor {
621 self.fd
622 }
623 }
624
/// Like a `UnixListener` but for accepting `UnixSeqpacket` type sockets.
///
/// The listening descriptor is owned and closed on drop. The socket file created by `bind` is
/// not removed on drop; wrap the listener in `UnlinkUnixSeqpacketListener` for that.
pub struct UnixSeqpacketListener {
    // Owned listening socket descriptor.
    fd: RawFd,
}
629
630 impl UnixSeqpacketListener {
631 /// Creates a new `UnixSeqpacketListener` bound to the given path.
bind<P: AsRef<Path>>(path: P) -> io::Result<Self>632 pub fn bind<P: AsRef<Path>>(path: P) -> io::Result<Self> {
633 // Safe socket initialization since we handle the returned error.
634 let fd = unsafe {
635 match libc::socket(libc::AF_UNIX, libc::SOCK_SEQPACKET, 0) {
636 -1 => return Err(io::Error::last_os_error()),
637 fd => fd,
638 }
639 };
640
641 let (addr, len) = sockaddr_un(path.as_ref())?;
642 // Safe connect since we handle the error and use the right length generated from
643 // `sockaddr_un`.
644 unsafe {
645 let ret = handle_eintr_errno!(libc::bind(fd, &addr as *const _ as *const _, len));
646 if ret < 0 {
647 return Err(io::Error::last_os_error());
648 }
649 let ret = handle_eintr_errno!(libc::listen(fd, 128));
650 if ret < 0 {
651 return Err(io::Error::last_os_error());
652 }
653 }
654 Ok(UnixSeqpacketListener { fd })
655 }
656
657 /// Blocks for and accepts a new incoming connection and returns the socket associated with that
658 /// connection.
659 ///
660 /// The returned socket has the close-on-exec flag set.
accept(&self) -> io::Result<UnixSeqpacket>661 pub fn accept(&self) -> io::Result<UnixSeqpacket> {
662 // Safe because we own this fd and the kernel will not write to null pointers.
663 let ret = unsafe { libc::accept4(self.fd, null_mut(), null_mut(), libc::SOCK_CLOEXEC) };
664 if ret < 0 {
665 return Err(io::Error::last_os_error());
666 }
667 // Safe because we checked the return value of accept. Therefore, the return value must be a
668 // valid socket.
669 Ok(unsafe { UnixSeqpacket::from_raw_fd(ret) })
670 }
671
672 /// Gets the path that this listener is bound to.
path(&self) -> io::Result<PathBuf>673 pub fn path(&self) -> io::Result<PathBuf> {
674 let mut addr = libc::sockaddr_un {
675 sun_family: libc::AF_UNIX as libc::sa_family_t,
676 sun_path: [0; 108],
677 };
678 let sun_path_offset = (&addr.sun_path as *const _ as usize
679 - &addr.sun_family as *const _ as usize)
680 as libc::socklen_t;
681 let mut len = mem::size_of::<libc::sockaddr_un>() as libc::socklen_t;
682 // Safe because the length given matches the length of the data of the given pointer, and we
683 // check the return value.
684 let ret = unsafe {
685 handle_eintr_errno!(libc::getsockname(
686 self.fd,
687 &mut addr as *mut libc::sockaddr_un as *mut libc::sockaddr,
688 &mut len
689 ))
690 };
691 if ret < 0 {
692 return Err(io::Error::last_os_error());
693 }
694 if addr.sun_family != libc::AF_UNIX as libc::sa_family_t
695 || addr.sun_path[0] == 0
696 || len < 1 + sun_path_offset
697 {
698 return Err(io::Error::new(
699 io::ErrorKind::InvalidInput,
700 "getsockname on socket returned invalid value",
701 ));
702 }
703
704 let path_os_str = OsString::from_vec(
705 addr.sun_path[..(len - sun_path_offset - 1) as usize]
706 .iter()
707 .map(|&c| c as _)
708 .collect(),
709 );
710 Ok(path_os_str.into())
711 }
712 }
713
714 impl Drop for UnixSeqpacketListener {
drop(&mut self)715 fn drop(&mut self) {
716 // Safe if the UnixSeqpacketListener is created from Self::listen.
717 unsafe {
718 libc::close(self.fd);
719 }
720 }
721 }
722
723 impl FromRawFd for UnixSeqpacketListener {
724 // Unsafe in drop function
from_raw_fd(fd: RawFd) -> Self725 unsafe fn from_raw_fd(fd: RawFd) -> Self {
726 Self { fd }
727 }
728 }
729
730 impl AsRawFd for UnixSeqpacketListener {
as_raw_fd(&self) -> RawFd731 fn as_raw_fd(&self) -> RawFd {
732 self.fd
733 }
734 }
735
/// Used to attempt to clean up a `UnixSeqpacketListener` after it is dropped.
///
/// On drop, the wrapper looks up the listener's bound path and tries to remove that file. A
/// failure to remove it is logged as a warning, not returned.
pub struct UnlinkUnixSeqpacketListener(pub UnixSeqpacketListener);
738 impl AsRef<UnixSeqpacketListener> for UnlinkUnixSeqpacketListener {
as_ref(&self) -> &UnixSeqpacketListener739 fn as_ref(&self) -> &UnixSeqpacketListener {
740 &self.0
741 }
742 }
743
744 impl AsRawFd for UnlinkUnixSeqpacketListener {
as_raw_fd(&self) -> RawFd745 fn as_raw_fd(&self) -> RawFd {
746 self.0.as_raw_fd()
747 }
748 }
749
750 impl Deref for UnlinkUnixSeqpacketListener {
751 type Target = UnixSeqpacketListener;
deref(&self) -> &Self::Target752 fn deref(&self) -> &Self::Target {
753 &self.0
754 }
755 }
756
757 impl Drop for UnlinkUnixSeqpacketListener {
drop(&mut self)758 fn drop(&mut self) {
759 if let Ok(path) = self.0.path() {
760 if let Err(e) = remove_file(path) {
761 warn!("failed to remove control socket file: {:?}", e);
762 }
763 }
764 }
765 }
766
#[cfg(test)]
mod tests {
    use super::*;
    use std::env;
    use std::fs::remove_file;
    use std::io::ErrorKind;
    use std::path::PathBuf;

    fn tmpdir() -> PathBuf {
        env::temp_dir()
    }

    /// Returns `tmpdir()/name` after removing any stale socket file there.
    ///
    /// A previous run that was killed before `UnlinkUnixSeqpacketListener` could clean up
    /// leaves the file behind, and `bind` would then fail with `EADDRINUSE`.
    fn fresh_socket_path(name: &str) -> PathBuf {
        let mut socket_path = tmpdir();
        socket_path.push(name);
        // Ignore the error: normally the file does not exist.
        let _ = remove_file(&socket_path);
        socket_path
    }

    #[test]
    fn sockaddr_un_zero_length_input() {
        let _res = sockaddr_un(Path::new("")).expect("sockaddr_un failed");
    }

    #[test]
    fn sockaddr_un_long_input_err() {
        let res = sockaddr_un(Path::new(&"a".repeat(108)));
        assert!(res.is_err());
    }

    #[test]
    fn sockaddr_un_long_input_pass() {
        let _res = sockaddr_un(Path::new(&"a".repeat(107))).expect("sockaddr_un failed");
    }

    #[test]
    fn sockaddr_un_len_check() {
        let (_addr, len) = sockaddr_un(Path::new(&"a".repeat(50))).expect("sockaddr_un failed");
        assert_eq!(len, (sun_path_offset() + 50 + 1) as u32);
    }

    #[test]
    fn sockaddr_un_pass() {
        let path_size = 50;
        let (addr, len) =
            sockaddr_un(Path::new(&"a".repeat(path_size))).expect("sockaddr_un failed");
        assert_eq!(len, (sun_path_offset() + path_size + 1) as u32);
        assert_eq!(addr.sun_family, libc::AF_UNIX as libc::sa_family_t);

        // Check `sun_path` in returned `sockaddr_un`
        let mut ref_sun_path = [0 as libc::c_char; 108];
        for path in ref_sun_path.iter_mut().take(path_size) {
            *path = 'a' as libc::c_char;
        }

        for (addr_char, ref_char) in addr.sun_path.iter().zip(ref_sun_path.iter()) {
            assert_eq!(addr_char, ref_char);
        }
    }

    #[test]
    fn unix_seqpacket_path_not_exists() {
        let res = UnixSeqpacket::connect("/path/not/exists");
        assert!(res.is_err());
    }

    #[test]
    fn unix_seqpacket_listener_path() {
        let socket_path = fresh_socket_path("unix_seqpacket_listener_path");
        let listener = UnlinkUnixSeqpacketListener(
            UnixSeqpacketListener::bind(&socket_path)
                .expect("failed to create UnixSeqpacketListener"),
        );
        let listener_path = listener.path().expect("failed to get socket listener path");
        assert_eq!(socket_path, listener_path);
    }

    #[test]
    fn unix_seqpacket_path_exists_pass() {
        let socket_path = fresh_socket_path("path_to_socket");
        let _listener = UnlinkUnixSeqpacketListener(
            UnixSeqpacketListener::bind(&socket_path)
                .expect("failed to create UnixSeqpacketListener"),
        );
        let _res =
            UnixSeqpacket::connect(socket_path.as_path()).expect("UnixSeqpacket::connect failed");
    }

    #[test]
    fn unix_seqpacket_path_listener_accept() {
        let socket_path = fresh_socket_path("path_listerner_accept");
        let listener = UnlinkUnixSeqpacketListener(
            UnixSeqpacketListener::bind(&socket_path)
                .expect("failed to create UnixSeqpacketListener"),
        );
        let s1 =
            UnixSeqpacket::connect(socket_path.as_path()).expect("UnixSeqpacket::connect failed");

        let s2 = listener.accept().expect("UnixSeqpacket::accept failed");

        let data1 = &[0, 1, 2, 3, 4];
        let data2 = &[10, 11, 12, 13, 14];
        s2.send(data2).expect("failed to send data2");
        s1.send(data1).expect("failed to send data1");
        let recv_data = &mut [0; 5];
        s2.recv(recv_data).expect("failed to recv data");
        assert_eq!(data1, recv_data);
        s1.recv(recv_data).expect("failed to recv data");
        assert_eq!(data2, recv_data);
    }

    #[test]
    fn unix_seqpacket_zero_timeout() {
        let (s1, _s2) = UnixSeqpacket::pair().expect("failed to create socket pair");
        // Timeouts less than a microsecond are too small and round to zero.
        s1.set_read_timeout(Some(Duration::from_nanos(10)))
            .expect_err("successfully set zero timeout");
    }

    #[test]
    fn unix_seqpacket_read_timeout() {
        let (s1, _s2) = UnixSeqpacket::pair().expect("failed to create socket pair");
        s1.set_read_timeout(Some(Duration::from_millis(1)))
            .expect("failed to set read timeout for socket");
        // Nothing was sent and the peer is still open, so the read must time out.
        let err = s1
            .recv(&mut [0])
            .expect_err("recv unexpectedly succeeded with no data sent");
        assert_eq!(err.kind(), ErrorKind::WouldBlock);
    }

    #[test]
    fn unix_seqpacket_write_timeout() {
        let (s1, _s2) = UnixSeqpacket::pair().expect("failed to create socket pair");
        s1.set_write_timeout(Some(Duration::from_millis(1)))
            .expect("failed to set write timeout for socket");
    }

    #[test]
    fn unix_seqpacket_send_recv() {
        let (s1, s2) = UnixSeqpacket::pair().expect("failed to create socket pair");
        let data1 = &[0, 1, 2, 3, 4];
        let data2 = &[10, 11, 12, 13, 14];
        s2.send(data2).expect("failed to send data2");
        s1.send(data1).expect("failed to send data1");
        let recv_data = &mut [0; 5];
        s2.recv(recv_data).expect("failed to recv data");
        assert_eq!(data1, recv_data);
        s1.recv(recv_data).expect("failed to recv data");
        assert_eq!(data2, recv_data);
    }

    #[test]
    fn unix_seqpacket_send_fragments() {
        let (s1, s2) = UnixSeqpacket::pair().expect("failed to create socket pair");
        let data1 = &[0, 1, 2, 3, 4];
        let data2 = &[10, 11, 12, 13, 14, 15, 16];
        s1.send(data1).expect("failed to send data1");
        s1.send(data2).expect("failed to send data2");

        let recv_data = &mut [0; 32];
        let size = s2.recv(recv_data).expect("failed to recv data");
        assert_eq!(size, data1.len());
        assert_eq!(data1, &recv_data[0..size]);

        let size = s2.recv(recv_data).expect("failed to recv data");
        assert_eq!(size, data2.len());
        assert_eq!(data2, &recv_data[0..size]);
    }

    #[test]
    fn unix_seqpacket_get_readable_bytes() {
        let (s1, s2) = UnixSeqpacket::pair().expect("failed to create socket pair");
        assert_eq!(s1.get_readable_bytes().unwrap(), 0);
        assert_eq!(s2.get_readable_bytes().unwrap(), 0);
        let data1 = &[0, 1, 2, 3, 4];
        s1.send(data1).expect("failed to send data");

        assert_eq!(s1.get_readable_bytes().unwrap(), 0);
        assert_eq!(s2.get_readable_bytes().unwrap(), data1.len());

        let recv_data = &mut [0; 5];
        s2.recv(recv_data).expect("failed to recv data");
        assert_eq!(s1.get_readable_bytes().unwrap(), 0);
        assert_eq!(s2.get_readable_bytes().unwrap(), 0);
    }

    #[test]
    fn unix_seqpacket_next_packet_size() {
        let (s1, s2) = UnixSeqpacket::pair().expect("failed to create socket pair");
        let data1 = &[0, 1, 2, 3, 4];
        s1.send(data1).expect("failed to send data");

        assert_eq!(s2.next_packet_size().unwrap(), 5);
        s1.set_read_timeout(Some(Duration::from_micros(1)))
            .expect("failed to set read timeout");
        assert_eq!(
            s1.next_packet_size().unwrap_err().kind(),
            ErrorKind::WouldBlock
        );
        drop(s2);
        assert_eq!(
            s1.next_packet_size().unwrap_err().kind(),
            ErrorKind::ConnectionReset
        );
    }

    #[test]
    fn unix_seqpacket_recv_to_vec() {
        let (s1, s2) = UnixSeqpacket::pair().expect("failed to create socket pair");
        let data1 = &[0, 1, 2, 3, 4];
        s1.send(data1).expect("failed to send data");

        let recv_data = &mut vec![];
        s2.recv_to_vec(recv_data).expect("failed to recv data");
        assert_eq!(recv_data, &mut vec![0, 1, 2, 3, 4]);
    }

    #[test]
    fn unix_seqpacket_recv_as_vec() {
        let (s1, s2) = UnixSeqpacket::pair().expect("failed to create socket pair");
        let data1 = &[0, 1, 2, 3, 4];
        s1.send(data1).expect("failed to send data");

        let recv_data = s2.recv_as_vec().expect("failed to recv data");
        assert_eq!(recv_data, vec![0, 1, 2, 3, 4]);
    }
}
986