• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 // Copyright (C) 2019 Alibaba Cloud Computing. All rights reserved.
2 // SPDX-License-Identifier: Apache-2.0
3 
4 //! Structs for Unix Domain Socket listener and endpoint.
5 
6 #![allow(dead_code)]
7 
8 use std::io::ErrorKind;
9 use std::marker::PhantomData;
10 use std::os::unix::io::{AsRawFd, RawFd};
11 use std::os::unix::net::{UnixListener, UnixStream};
12 use std::path::{Path, PathBuf};
13 use std::{mem, slice};
14 
15 use libc::{c_void, iovec};
16 use sys_util::ScmSocket;
17 
18 use super::message::*;
19 use super::{Error, Result};
20 
21 /// Unix domain socket listener for accepting incoming connections.
22 pub struct Listener {
23     fd: UnixListener,
24     path: PathBuf,
25 }
26 
27 impl Listener {
28     /// Create a unix domain socket listener.
29     ///
30     /// # Return:
31     /// * - the new Listener object on success.
32     /// * - SocketError: failed to create listener socket.
new<P: AsRef<Path>>(path: P, unlink: bool) -> Result<Self>33     pub fn new<P: AsRef<Path>>(path: P, unlink: bool) -> Result<Self> {
34         if unlink {
35             let _ = std::fs::remove_file(&path);
36         }
37         let fd = UnixListener::bind(&path).map_err(Error::SocketError)?;
38         Ok(Listener {
39             fd,
40             path: path.as_ref().to_owned(),
41         })
42     }
43 
44     /// Accept an incoming connection.
45     ///
46     /// # Return:
47     /// * - Some(UnixStream): new UnixStream object if new incoming connection is available.
48     /// * - None: no incoming connection available.
49     /// * - SocketError: errors from accept().
accept(&self) -> Result<Option<UnixStream>>50     pub fn accept(&self) -> Result<Option<UnixStream>> {
51         loop {
52             match self.fd.accept() {
53                 Ok((socket, _addr)) => return Ok(Some(socket)),
54                 Err(e) => {
55                     match e.kind() {
56                         // No incoming connection available.
57                         ErrorKind::WouldBlock => return Ok(None),
58                         // New connection closed by peer.
59                         ErrorKind::ConnectionAborted => return Ok(None),
60                         // Interrupted by signals, retry
61                         ErrorKind::Interrupted => continue,
62                         _ => return Err(Error::SocketError(e)),
63                     }
64                 }
65             }
66         }
67     }
68 
69     /// Change blocking status on the listener.
70     ///
71     /// # Return:
72     /// * - () on success.
73     /// * - SocketError: failure from set_nonblocking().
set_nonblocking(&self, block: bool) -> Result<()>74     pub fn set_nonblocking(&self, block: bool) -> Result<()> {
75         self.fd.set_nonblocking(block).map_err(Error::SocketError)
76     }
77 }
78 
79 impl AsRawFd for Listener {
as_raw_fd(&self) -> RawFd80     fn as_raw_fd(&self) -> RawFd {
81         self.fd.as_raw_fd()
82     }
83 }
84 
85 impl Drop for Listener {
drop(&mut self)86     fn drop(&mut self) {
87         let _ = std::fs::remove_file(&self.path);
88     }
89 }
90 
/// Unix domain socket endpoint for vhost-user connection.
///
/// `R` is the request type carried in the `VhostUserMsgHeader<R>` values this endpoint
/// sends and receives.
pub(super) struct Endpoint<R: Req> {
    /// The connected stream socket all traffic goes through.
    sock: UnixStream,
    /// Zero-sized marker binding the endpoint to its request type `R`.
    _r: PhantomData<R>,
}
96 
97 impl<R: Req> Endpoint<R> {
98     /// Create a new stream by connecting to server at `str`.
99     ///
100     /// # Return:
101     /// * - the new Endpoint object on success.
102     /// * - SocketConnect: failed to connect to peer.
connect<P: AsRef<Path>>(path: P) -> Result<Self>103     pub fn connect<P: AsRef<Path>>(path: P) -> Result<Self> {
104         let sock = UnixStream::connect(path).map_err(Error::SocketConnect)?;
105         Ok(Self::from_stream(sock))
106     }
107 
108     /// Create an endpoint from a stream object.
from_stream(sock: UnixStream) -> Self109     pub fn from_stream(sock: UnixStream) -> Self {
110         Endpoint {
111             sock,
112             _r: PhantomData,
113         }
114     }
115 
116     /// Sends bytes from scatter-gather vectors over the socket with optional attached file
117     /// descriptors.
118     ///
119     /// # Return:
120     /// * - number of bytes sent on success
121     /// * - SocketRetry: temporary error caused by signals or short of resources.
122     /// * - SocketBroken: the underline socket is broken.
123     /// * - SocketError: other socket related errors.
send_iovec(&mut self, iovs: &[&[u8]], fds: Option<&[RawFd]>) -> Result<usize>124     pub fn send_iovec(&mut self, iovs: &[&[u8]], fds: Option<&[RawFd]>) -> Result<usize> {
125         let rfds = match fds {
126             Some(rfds) => rfds,
127             _ => &[],
128         };
129         self.sock.send_bufs_with_fds(iovs, rfds).map_err(Into::into)
130     }
131 
132     /// Sends all bytes from scatter-gather vectors over the socket with optional attached file
133     /// descriptors. Will loop until all data has been transfered.
134     ///
135     /// # Return:
136     /// * - number of bytes sent on success
137     /// * - SocketBroken: the underline socket is broken.
138     /// * - SocketError: other socket related errors.
send_iovec_all(&mut self, iovs: &[&[u8]], fds: Option<&[RawFd]>) -> Result<usize>139     pub fn send_iovec_all(&mut self, iovs: &[&[u8]], fds: Option<&[RawFd]>) -> Result<usize> {
140         let mut data_sent = 0;
141         let mut data_total = 0;
142         let iov_lens: Vec<usize> = iovs.iter().map(|iov| iov.len()).collect();
143         for len in &iov_lens {
144             data_total += len;
145         }
146 
147         while (data_total - data_sent) > 0 {
148             let (nr_skip, offset) = get_sub_iovs_offset(&iov_lens, data_sent);
149             let iov = &iovs[nr_skip][offset..];
150 
151             let data = &[&[iov], &iovs[(nr_skip + 1)..]].concat();
152             let sfds = if data_sent == 0 { fds } else { None };
153 
154             let sent = self.send_iovec(data, sfds);
155             match sent {
156                 Ok(0) => return Ok(data_sent),
157                 Ok(n) => data_sent += n,
158                 Err(e) => match e {
159                     Error::SocketRetry(_) => {}
160                     _ => return Err(e),
161                 },
162             }
163         }
164         Ok(data_sent)
165     }
166 
167     /// Sends bytes from a slice over the socket with optional attached file descriptors.
168     ///
169     /// # Return:
170     /// * - number of bytes sent on success
171     /// * - SocketRetry: temporary error caused by signals or short of resources.
172     /// * - SocketBroken: the underline socket is broken.
173     /// * - SocketError: other socket related errors.
send_slice(&mut self, data: &[u8], fds: Option<&[RawFd]>) -> Result<usize>174     pub fn send_slice(&mut self, data: &[u8], fds: Option<&[RawFd]>) -> Result<usize> {
175         self.send_iovec(&[data], fds)
176     }
177 
178     /// Sends a header-only message with optional attached file descriptors.
179     ///
180     /// # Return:
181     /// * - number of bytes sent on success
182     /// * - SocketRetry: temporary error caused by signals or short of resources.
183     /// * - SocketBroken: the underline socket is broken.
184     /// * - SocketError: other socket related errors.
185     /// * - PartialMessage: received a partial message.
send_header( &mut self, hdr: &VhostUserMsgHeader<R>, fds: Option<&[RawFd]>, ) -> Result<()>186     pub fn send_header(
187         &mut self,
188         hdr: &VhostUserMsgHeader<R>,
189         fds: Option<&[RawFd]>,
190     ) -> Result<()> {
191         // Safe because there can't be other mutable referance to hdr.
192         let iovs = unsafe {
193             [slice::from_raw_parts(
194                 hdr as *const VhostUserMsgHeader<R> as *const u8,
195                 mem::size_of::<VhostUserMsgHeader<R>>(),
196             )]
197         };
198         let bytes = self.send_iovec_all(&iovs[..], fds)?;
199         if bytes != mem::size_of::<VhostUserMsgHeader<R>>() {
200             return Err(Error::PartialMessage);
201         }
202         Ok(())
203     }
204 
205     /// Send a message with header and body. Optional file descriptors may be attached to
206     /// the message.
207     ///
208     /// # Return:
209     /// * - number of bytes sent on success
210     /// * - SocketRetry: temporary error caused by signals or short of resources.
211     /// * - SocketBroken: the underline socket is broken.
212     /// * - SocketError: other socket related errors.
213     /// * - PartialMessage: received a partial message.
send_message<T: Sized>( &mut self, hdr: &VhostUserMsgHeader<R>, body: &T, fds: Option<&[RawFd]>, ) -> Result<()>214     pub fn send_message<T: Sized>(
215         &mut self,
216         hdr: &VhostUserMsgHeader<R>,
217         body: &T,
218         fds: Option<&[RawFd]>,
219     ) -> Result<()> {
220         if mem::size_of::<T>() > MAX_MSG_SIZE {
221             return Err(Error::OversizedMsg);
222         }
223         // Safe because there can't be other mutable referance to hdr and body.
224         let iovs = unsafe {
225             [
226                 slice::from_raw_parts(
227                     hdr as *const VhostUserMsgHeader<R> as *const u8,
228                     mem::size_of::<VhostUserMsgHeader<R>>(),
229                 ),
230                 slice::from_raw_parts(body as *const T as *const u8, mem::size_of::<T>()),
231             ]
232         };
233         let bytes = self.send_iovec_all(&iovs[..], fds)?;
234         if bytes != mem::size_of::<VhostUserMsgHeader<R>>() + mem::size_of::<T>() {
235             return Err(Error::PartialMessage);
236         }
237         Ok(())
238     }
239 
240     /// Send a message with header, body and payload. Optional file descriptors
241     /// may also be attached to the message.
242     ///
243     /// # Return:
244     /// * - number of bytes sent on success
245     /// * - SocketRetry: temporary error caused by signals or short of resources.
246     /// * - SocketBroken: the underline socket is broken.
247     /// * - SocketError: other socket related errors.
248     /// * - OversizedMsg: message size is too big.
249     /// * - PartialMessage: received a partial message.
250     /// * - IncorrectFds: wrong number of attached fds.
send_message_with_payload<T: Sized>( &mut self, hdr: &VhostUserMsgHeader<R>, body: &T, payload: &[u8], fds: Option<&[RawFd]>, ) -> Result<()>251     pub fn send_message_with_payload<T: Sized>(
252         &mut self,
253         hdr: &VhostUserMsgHeader<R>,
254         body: &T,
255         payload: &[u8],
256         fds: Option<&[RawFd]>,
257     ) -> Result<()> {
258         let len = payload.len();
259         if mem::size_of::<T>() > MAX_MSG_SIZE {
260             return Err(Error::OversizedMsg);
261         }
262         if len > MAX_MSG_SIZE - mem::size_of::<T>() {
263             return Err(Error::OversizedMsg);
264         }
265         if let Some(fd_arr) = fds {
266             if fd_arr.len() > MAX_ATTACHED_FD_ENTRIES {
267                 return Err(Error::IncorrectFds);
268             }
269         }
270 
271         // Safe because there can't be other mutable reference to hdr, body and payload.
272         let iovs = unsafe {
273             [
274                 slice::from_raw_parts(
275                     hdr as *const VhostUserMsgHeader<R> as *const u8,
276                     mem::size_of::<VhostUserMsgHeader<R>>(),
277                 ),
278                 slice::from_raw_parts(body as *const T as *const u8, mem::size_of::<T>()),
279                 slice::from_raw_parts(payload.as_ptr() as *const u8, len),
280             ]
281         };
282         let total = mem::size_of::<VhostUserMsgHeader<R>>() + mem::size_of::<T>() + len;
283         let len = self.send_iovec_all(&iovs, fds)?;
284         if len != total {
285             return Err(Error::PartialMessage);
286         }
287         Ok(())
288     }
289 
290     /// Reads bytes from the socket into the given scatter/gather vectors.
291     ///
292     /// # Return:
293     /// * - (number of bytes received, buf) on success
294     /// * - SocketRetry: temporary error caused by signals or short of resources.
295     /// * - SocketBroken: the underline socket is broken.
296     /// * - SocketError: other socket related errors.
recv_data(&mut self, len: usize) -> Result<(usize, Vec<u8>)>297     pub fn recv_data(&mut self, len: usize) -> Result<(usize, Vec<u8>)> {
298         let mut rbuf = vec![0u8; len];
299         let (bytes, _) = self.sock.recv_with_fds(&mut rbuf[..], &mut [])?;
300         Ok((bytes, rbuf))
301     }
302 
303     /// Reads bytes from the socket into the given scatter/gather vectors with optional attached
304     /// file descriptors.
305     ///
306     /// The underlying communication channel is a Unix domain socket in STREAM mode. It's a little
307     /// tricky to pass file descriptors through such a communication channel. Let's assume that a
308     /// sender sending a message with some file descriptors attached. To successfully receive those
309     /// attached file descriptors, the receiver must obey following rules:
310     ///   1) file descriptors are attached to a message.
311     ///   2) message(packet) boundaries must be respected on the receive side.
312     /// In other words, recvmsg() operations must not cross the packet boundary, otherwise the
313     /// attached file descriptors will get lost.
314     ///
315     /// # Return:
316     /// * - (number of bytes received, [received fds]) on success
317     /// * - SocketRetry: temporary error caused by signals or short of resources.
318     /// * - SocketBroken: the underline socket is broken.
319     /// * - SocketError: other socket related errors.
recv_into_iovec(&mut self, iovs: &mut [iovec]) -> Result<(usize, Option<Vec<RawFd>>)>320     pub fn recv_into_iovec(&mut self, iovs: &mut [iovec]) -> Result<(usize, Option<Vec<RawFd>>)> {
321         let mut fd_array = vec![0; MAX_ATTACHED_FD_ENTRIES];
322         let (bytes, fds) = self.sock.recv_iovecs_with_fds(iovs, &mut fd_array)?;
323         let rfds = match fds {
324             0 => None,
325             n => {
326                 let mut fds = Vec::with_capacity(n);
327                 fds.extend_from_slice(&fd_array[0..n]);
328                 Some(fds)
329             }
330         };
331 
332         Ok((bytes, rfds))
333     }
334 
335     /// Reads all bytes from the socket into the given scatter/gather vectors with optional
336     /// attached file descriptors. Will loop until all data has been transfered.
337     ///
338     /// The underlying communication channel is a Unix domain socket in STREAM mode. It's a little
339     /// tricky to pass file descriptors through such a communication channel. Let's assume that a
340     /// sender sending a message with some file descriptors attached. To successfully receive those
341     /// attached file descriptors, the receiver must obey following rules:
342     ///   1) file descriptors are attached to a message.
343     ///   2) message(packet) boundaries must be respected on the receive side.
344     /// In other words, recvmsg() operations must not cross the packet boundary, otherwise the
345     /// attached file descriptors will get lost.
346     ///
347     /// # Return:
348     /// * - (number of bytes received, [received fds]) on success
349     /// * - SocketBroken: the underline socket is broken.
350     /// * - SocketError: other socket related errors.
recv_into_iovec_all( &mut self, iovs: &mut [iovec], ) -> Result<(usize, Option<Vec<RawFd>>)>351     pub fn recv_into_iovec_all(
352         &mut self,
353         iovs: &mut [iovec],
354     ) -> Result<(usize, Option<Vec<RawFd>>)> {
355         let mut data_read = 0;
356         let mut data_total = 0;
357         let mut rfds = None;
358         let iov_lens: Vec<usize> = iovs.iter().map(|iov| iov.iov_len).collect();
359         for len in &iov_lens {
360             data_total += len;
361         }
362 
363         while (data_total - data_read) > 0 {
364             let (nr_skip, offset) = get_sub_iovs_offset(&iov_lens, data_read);
365             let iov = &mut iovs[nr_skip];
366 
367             let mut data = [
368                 &[iovec {
369                     iov_base: (iov.iov_base as usize + offset) as *mut c_void,
370                     iov_len: iov.iov_len - offset,
371                 }],
372                 &iovs[(nr_skip + 1)..],
373             ]
374             .concat();
375 
376             let res = self.recv_into_iovec(&mut data);
377             match res {
378                 Ok((0, _)) => return Ok((data_read, rfds)),
379                 Ok((n, fds)) => {
380                     if data_read == 0 {
381                         rfds = fds;
382                     }
383                     data_read += n;
384                 }
385                 Err(e) => match e {
386                     Error::SocketRetry(_) => {}
387                     _ => return Err(e),
388                 },
389             }
390         }
391         Ok((data_read, rfds))
392     }
393 
394     /// Reads bytes from the socket into a new buffer with optional attached
395     /// file descriptors. Received file descriptors are set close-on-exec.
396     ///
397     /// # Return:
398     /// * - (number of bytes received, buf, [received fds]) on success.
399     /// * - SocketRetry: temporary error caused by signals or short of resources.
400     /// * - SocketBroken: the underline socket is broken.
401     /// * - SocketError: other socket related errors.
recv_into_buf( &mut self, buf_size: usize, ) -> Result<(usize, Vec<u8>, Option<Vec<RawFd>>)>402     pub fn recv_into_buf(
403         &mut self,
404         buf_size: usize,
405     ) -> Result<(usize, Vec<u8>, Option<Vec<RawFd>>)> {
406         let mut buf = vec![0u8; buf_size];
407         let (bytes, rfds) = {
408             let mut iovs = [iovec {
409                 iov_base: buf.as_mut_ptr() as *mut c_void,
410                 iov_len: buf_size,
411             }];
412             self.recv_into_iovec(&mut iovs)?
413         };
414         Ok((bytes, buf, rfds))
415     }
416 
417     /// Receive a header-only message with optional attached file descriptors.
418     /// Note, only the first MAX_ATTACHED_FD_ENTRIES file descriptors will be
419     /// accepted and all other file descriptor will be discard silently.
420     ///
421     /// # Return:
422     /// * - (message header, [received fds]) on success.
423     /// * - SocketRetry: temporary error caused by signals or short of resources.
424     /// * - SocketBroken: the underline socket is broken.
425     /// * - SocketError: other socket related errors.
426     /// * - PartialMessage: received a partial message.
427     /// * - InvalidMessage: received a invalid message.
recv_header(&mut self) -> Result<(VhostUserMsgHeader<R>, Option<Vec<RawFd>>)>428     pub fn recv_header(&mut self) -> Result<(VhostUserMsgHeader<R>, Option<Vec<RawFd>>)> {
429         let mut hdr = VhostUserMsgHeader::default();
430         let mut iovs = [iovec {
431             iov_base: (&mut hdr as *mut VhostUserMsgHeader<R>) as *mut c_void,
432             iov_len: mem::size_of::<VhostUserMsgHeader<R>>(),
433         }];
434         let (bytes, rfds) = self.recv_into_iovec_all(&mut iovs[..])?;
435 
436         if bytes != mem::size_of::<VhostUserMsgHeader<R>>() {
437             return Err(Error::PartialMessage);
438         } else if !hdr.is_valid() {
439             return Err(Error::InvalidMessage);
440         }
441 
442         Ok((hdr, rfds))
443     }
444 
445     /// Receive a message with optional attached file descriptors.
446     /// Note, only the first MAX_ATTACHED_FD_ENTRIES file descriptors will be
447     /// accepted and all other file descriptor will be discard silently.
448     ///
449     /// # Return:
450     /// * - (message header, message body, [received fds]) on success.
451     /// * - SocketRetry: temporary error caused by signals or short of resources.
452     /// * - SocketBroken: the underline socket is broken.
453     /// * - SocketError: other socket related errors.
454     /// * - PartialMessage: received a partial message.
455     /// * - InvalidMessage: received a invalid message.
recv_body<T: Sized + Default + VhostUserMsgValidator>( &mut self, ) -> Result<(VhostUserMsgHeader<R>, T, Option<Vec<RawFd>>)>456     pub fn recv_body<T: Sized + Default + VhostUserMsgValidator>(
457         &mut self,
458     ) -> Result<(VhostUserMsgHeader<R>, T, Option<Vec<RawFd>>)> {
459         let mut hdr = VhostUserMsgHeader::default();
460         let mut body: T = Default::default();
461         let mut iovs = [
462             iovec {
463                 iov_base: (&mut hdr as *mut VhostUserMsgHeader<R>) as *mut c_void,
464                 iov_len: mem::size_of::<VhostUserMsgHeader<R>>(),
465             },
466             iovec {
467                 iov_base: (&mut body as *mut T) as *mut c_void,
468                 iov_len: mem::size_of::<T>(),
469             },
470         ];
471         let (bytes, rfds) = self.recv_into_iovec_all(&mut iovs[..])?;
472 
473         let total = mem::size_of::<VhostUserMsgHeader<R>>() + mem::size_of::<T>();
474         if bytes != total {
475             return Err(Error::PartialMessage);
476         } else if !hdr.is_valid() || !body.is_valid() {
477             return Err(Error::InvalidMessage);
478         }
479 
480         Ok((hdr, body, rfds))
481     }
482 
483     /// Receive a message with header and optional content. Callers need to
484     /// pre-allocate a big enough buffer to receive the message body and
485     /// optional payload. If there are attached file descriptor associated
486     /// with the message, the first MAX_ATTACHED_FD_ENTRIES file descriptors
487     /// will be accepted and all other file descriptor will be discard
488     /// silently.
489     ///
490     /// # Return:
491     /// * - (message header, message size, [received fds]) on success.
492     /// * - SocketRetry: temporary error caused by signals or short of resources.
493     /// * - SocketBroken: the underline socket is broken.
494     /// * - SocketError: other socket related errors.
495     /// * - PartialMessage: received a partial message.
496     /// * - InvalidMessage: received a invalid message.
recv_body_into_buf( &mut self, buf: &mut [u8], ) -> Result<(VhostUserMsgHeader<R>, usize, Option<Vec<RawFd>>)>497     pub fn recv_body_into_buf(
498         &mut self,
499         buf: &mut [u8],
500     ) -> Result<(VhostUserMsgHeader<R>, usize, Option<Vec<RawFd>>)> {
501         let mut hdr = VhostUserMsgHeader::default();
502         let mut iovs = [
503             iovec {
504                 iov_base: (&mut hdr as *mut VhostUserMsgHeader<R>) as *mut c_void,
505                 iov_len: mem::size_of::<VhostUserMsgHeader<R>>(),
506             },
507             iovec {
508                 iov_base: buf.as_mut_ptr() as *mut c_void,
509                 iov_len: buf.len(),
510             },
511         ];
512         let (bytes, rfds) = self.recv_into_iovec_all(&mut iovs[..])?;
513 
514         if bytes < mem::size_of::<VhostUserMsgHeader<R>>() {
515             return Err(Error::PartialMessage);
516         } else if !hdr.is_valid() {
517             return Err(Error::InvalidMessage);
518         }
519 
520         Ok((hdr, bytes - mem::size_of::<VhostUserMsgHeader<R>>(), rfds))
521     }
522 
523     /// Receive a message with optional payload and attached file descriptors.
524     /// Note, only the first MAX_ATTACHED_FD_ENTRIES file descriptors will be
525     /// accepted and all other file descriptor will be discard silently.
526     ///
527     /// # Return:
528     /// * - (message header, message body, size of payload, [received fds]) on success.
529     /// * - SocketRetry: temporary error caused by signals or short of resources.
530     /// * - SocketBroken: the underline socket is broken.
531     /// * - SocketError: other socket related errors.
532     /// * - PartialMessage: received a partial message.
533     /// * - InvalidMessage: received a invalid message.
534     #[cfg_attr(feature = "cargo-clippy", allow(clippy::type_complexity))]
recv_payload_into_buf<T: Sized + Default + VhostUserMsgValidator>( &mut self, buf: &mut [u8], ) -> Result<(VhostUserMsgHeader<R>, T, usize, Option<Vec<RawFd>>)>535     pub fn recv_payload_into_buf<T: Sized + Default + VhostUserMsgValidator>(
536         &mut self,
537         buf: &mut [u8],
538     ) -> Result<(VhostUserMsgHeader<R>, T, usize, Option<Vec<RawFd>>)> {
539         let mut hdr = VhostUserMsgHeader::default();
540         let mut body: T = Default::default();
541         let mut iovs = [
542             iovec {
543                 iov_base: (&mut hdr as *mut VhostUserMsgHeader<R>) as *mut c_void,
544                 iov_len: mem::size_of::<VhostUserMsgHeader<R>>(),
545             },
546             iovec {
547                 iov_base: (&mut body as *mut T) as *mut c_void,
548                 iov_len: mem::size_of::<T>(),
549             },
550             iovec {
551                 iov_base: buf.as_mut_ptr() as *mut c_void,
552                 iov_len: buf.len(),
553             },
554         ];
555         let (bytes, rfds) = self.recv_into_iovec_all(&mut iovs[..])?;
556 
557         let total = mem::size_of::<VhostUserMsgHeader<R>>() + mem::size_of::<T>();
558         if bytes < total {
559             return Err(Error::PartialMessage);
560         } else if !hdr.is_valid() || !body.is_valid() {
561             return Err(Error::InvalidMessage);
562         }
563 
564         Ok((hdr, body, bytes - total, rfds))
565     }
566 
567     /// Close all raw file descriptors.
close_rfds(rfds: Option<Vec<RawFd>>)568     pub fn close_rfds(rfds: Option<Vec<RawFd>>) {
569         if let Some(fds) = rfds {
570             for fd in fds {
571                 // safe because the rawfds are valid and we don't care about the result.
572                 let _ = unsafe { libc::close(fd) };
573             }
574         }
575     }
576 }
577 
578 impl<T: Req> AsRawFd for Endpoint<T> {
as_raw_fd(&self) -> RawFd579     fn as_raw_fd(&self) -> RawFd {
580         self.sock.as_raw_fd()
581     }
582 }
583 
// Locate byte position `skip_size` within a sequence of buffers whose lengths are
// `iov_lens`.
//
// Returns `(index, offset)`:
// - `index` is the number of leading buffers that lie entirely before `skip_size`;
//   zero-length buffers at that position also count as skipped.
// - `offset` is the position of `skip_size` inside buffer `index`.
// If `skip_size` reaches or passes the end, `index == iov_lens.len()` and `offset` is what
// remains after subtracting every length.
//
// For example:
//     let iov_lens = vec![4, 4, 5];
//     let size = 6;
//     assert_eq!(get_sub_iovs_offset(&iov_lens, size), (1, 2));
fn get_sub_iovs_offset(iov_lens: &[usize], skip_size: usize) -> (usize, usize) {
    let mut remaining = skip_size;
    for (idx, &len) in iov_lens.iter().enumerate() {
        if remaining < len {
            return (idx, remaining);
        }
        remaining -= len;
    }
    (iov_lens.len(), remaining)
}
603 
604 #[cfg(test)]
605 mod tests {
606     use super::*;
607     use std::fs::File;
608     use std::io::{Read, Seek, SeekFrom, Write};
609     use std::os::unix::io::FromRawFd;
610     use tempfile::{tempfile, Builder, TempDir};
611 
temp_dir() -> TempDir612     fn temp_dir() -> TempDir {
613         Builder::new().prefix("/tmp/vhost_test").tempdir().unwrap()
614     }
615 
616     #[test]
create_listener()617     fn create_listener() {
618         let dir = temp_dir();
619         let mut path = dir.path().to_owned();
620         path.push("sock");
621         let listener = Listener::new(&path, true).unwrap();
622 
623         assert!(listener.as_raw_fd() > 0);
624     }
625 
626     #[test]
accept_connection()627     fn accept_connection() {
628         let dir = temp_dir();
629         let mut path = dir.path().to_owned();
630         path.push("sock");
631         let listener = Listener::new(&path, true).unwrap();
632         listener.set_nonblocking(true).unwrap();
633 
634         // accept on a fd without incoming connection
635         let conn = listener.accept().unwrap();
636         assert!(conn.is_none());
637     }
638 
639     #[test]
send_data()640     fn send_data() {
641         let dir = temp_dir();
642         let mut path = dir.path().to_owned();
643         path.push("sock");
644         let listener = Listener::new(&path, true).unwrap();
645         listener.set_nonblocking(true).unwrap();
646         let mut master = Endpoint::<MasterReq>::connect(&path).unwrap();
647         let sock = listener.accept().unwrap().unwrap();
648         let mut slave = Endpoint::<MasterReq>::from_stream(sock);
649 
650         let buf1 = vec![0x1, 0x2, 0x3, 0x4];
651         let mut len = master.send_slice(&buf1[..], None).unwrap();
652         assert_eq!(len, 4);
653         let (bytes, buf2, _) = slave.recv_into_buf(0x1000).unwrap();
654         assert_eq!(bytes, 4);
655         assert_eq!(&buf1[..], &buf2[..bytes]);
656 
657         len = master.send_slice(&buf1[..], None).unwrap();
658         assert_eq!(len, 4);
659         let (bytes, buf2, _) = slave.recv_into_buf(0x2).unwrap();
660         assert_eq!(bytes, 2);
661         assert_eq!(&buf1[..2], &buf2[..]);
662         let (bytes, buf2, _) = slave.recv_into_buf(0x2).unwrap();
663         assert_eq!(bytes, 2);
664         assert_eq!(&buf1[2..], &buf2[..]);
665     }
666 
667     #[test]
send_fd()668     fn send_fd() {
669         let dir = temp_dir();
670         let mut path = dir.path().to_owned();
671         path.push("sock");
672         let listener = Listener::new(&path, true).unwrap();
673         listener.set_nonblocking(true).unwrap();
674         let mut master = Endpoint::<MasterReq>::connect(&path).unwrap();
675         let sock = listener.accept().unwrap().unwrap();
676         let mut slave = Endpoint::<MasterReq>::from_stream(sock);
677 
678         let mut fd = tempfile().unwrap();
679         write!(fd, "test").unwrap();
680 
681         // Normal case for sending/receiving file descriptors
682         let buf1 = vec![0x1, 0x2, 0x3, 0x4];
683         let len = master
684             .send_slice(&buf1[..], Some(&[fd.as_raw_fd()]))
685             .unwrap();
686         assert_eq!(len, 4);
687 
688         let (bytes, buf2, rfds) = slave.recv_into_buf(4).unwrap();
689         assert_eq!(bytes, 4);
690         assert_eq!(&buf1[..], &buf2[..]);
691         assert!(rfds.is_some());
692         let fds = rfds.unwrap();
693         {
694             assert_eq!(fds.len(), 1);
695             let mut file = unsafe { File::from_raw_fd(fds[0]) };
696             let mut content = String::new();
697             file.seek(SeekFrom::Start(0)).unwrap();
698             file.read_to_string(&mut content).unwrap();
699             assert_eq!(content, "test");
700         }
701 
702         // Following communication pattern should work:
703         // Sending side: data(header, body) with fds
704         // Receiving side: data(header) with fds, data(body)
705         let len = master
706             .send_slice(
707                 &buf1[..],
708                 Some(&[fd.as_raw_fd(), fd.as_raw_fd(), fd.as_raw_fd()]),
709             )
710             .unwrap();
711         assert_eq!(len, 4);
712 
713         let (bytes, buf2, rfds) = slave.recv_into_buf(0x2).unwrap();
714         assert_eq!(bytes, 2);
715         assert_eq!(&buf1[..2], &buf2[..]);
716         assert!(rfds.is_some());
717         let fds = rfds.unwrap();
718         {
719             assert_eq!(fds.len(), 3);
720             let mut file = unsafe { File::from_raw_fd(fds[1]) };
721             let mut content = String::new();
722             file.seek(SeekFrom::Start(0)).unwrap();
723             file.read_to_string(&mut content).unwrap();
724             assert_eq!(content, "test");
725         }
726         let (bytes, buf2, rfds) = slave.recv_into_buf(0x2).unwrap();
727         assert_eq!(bytes, 2);
728         assert_eq!(&buf1[2..], &buf2[..]);
729         assert!(rfds.is_none());
730 
731         // Following communication pattern should not work:
732         // Sending side: data(header, body) with fds
733         // Receiving side: data(header), data(body) with fds
734         let len = master
735             .send_slice(
736                 &buf1[..],
737                 Some(&[fd.as_raw_fd(), fd.as_raw_fd(), fd.as_raw_fd()]),
738             )
739             .unwrap();
740         assert_eq!(len, 4);
741 
742         let (bytes, buf4) = slave.recv_data(2).unwrap();
743         assert_eq!(bytes, 2);
744         assert_eq!(&buf1[..2], &buf4[..]);
745         let (bytes, buf2, rfds) = slave.recv_into_buf(0x2).unwrap();
746         assert_eq!(bytes, 2);
747         assert_eq!(&buf1[2..], &buf2[..]);
748         assert!(rfds.is_none());
749 
750         // Following communication pattern should work:
751         // Sending side: data, data with fds
752         // Receiving side: data, data with fds
753         let len = master.send_slice(&buf1[..], None).unwrap();
754         assert_eq!(len, 4);
755         let len = master
756             .send_slice(
757                 &buf1[..],
758                 Some(&[fd.as_raw_fd(), fd.as_raw_fd(), fd.as_raw_fd()]),
759             )
760             .unwrap();
761         assert_eq!(len, 4);
762 
763         let (bytes, buf2, rfds) = slave.recv_into_buf(0x4).unwrap();
764         assert_eq!(bytes, 4);
765         assert_eq!(&buf1[..], &buf2[..]);
766         assert!(rfds.is_none());
767 
768         let (bytes, buf2, rfds) = slave.recv_into_buf(0x2).unwrap();
769         assert_eq!(bytes, 2);
770         assert_eq!(&buf1[..2], &buf2[..]);
771         assert!(rfds.is_some());
772         let fds = rfds.unwrap();
773         {
774             assert_eq!(fds.len(), 3);
775             let mut file = unsafe { File::from_raw_fd(fds[1]) };
776             let mut content = String::new();
777             file.seek(SeekFrom::Start(0)).unwrap();
778             file.read_to_string(&mut content).unwrap();
779             assert_eq!(content, "test");
780         }
781         let (bytes, buf2, rfds) = slave.recv_into_buf(0x2).unwrap();
782         assert_eq!(bytes, 2);
783         assert_eq!(&buf1[2..], &buf2[..]);
784         assert!(rfds.is_none());
785 
786         // Following communication pattern should not work:
787         // Sending side: data1, data2 with fds
788         // Receiving side: data + partial of data2, left of data2 with fds
789         let len = master.send_slice(&buf1[..], None).unwrap();
790         assert_eq!(len, 4);
791         let len = master
792             .send_slice(
793                 &buf1[..],
794                 Some(&[fd.as_raw_fd(), fd.as_raw_fd(), fd.as_raw_fd()]),
795             )
796             .unwrap();
797         assert_eq!(len, 4);
798 
799         let (bytes, _) = slave.recv_data(5).unwrap();
800         assert_eq!(bytes, 5);
801 
802         let (bytes, _, rfds) = slave.recv_into_buf(0x4).unwrap();
803         assert_eq!(bytes, 3);
804         assert!(rfds.is_none());
805 
806         // If the target fd array is too small, extra file descriptors will get lost.
807         let len = master
808             .send_slice(
809                 &buf1[..],
810                 Some(&[fd.as_raw_fd(), fd.as_raw_fd(), fd.as_raw_fd()]),
811             )
812             .unwrap();
813         assert_eq!(len, 4);
814 
815         let (bytes, _, rfds) = slave.recv_into_buf(0x4).unwrap();
816         assert_eq!(bytes, 4);
817         assert!(rfds.is_some());
818 
819         Endpoint::<MasterReq>::close_rfds(rfds);
820         Endpoint::<MasterReq>::close_rfds(None);
821     }
822 
823     #[test]
send_recv()824     fn send_recv() {
825         let dir = temp_dir();
826         let mut path = dir.path().to_owned();
827         path.push("sock");
828         let listener = Listener::new(&path, true).unwrap();
829         listener.set_nonblocking(true).unwrap();
830         let mut master = Endpoint::<MasterReq>::connect(&path).unwrap();
831         let sock = listener.accept().unwrap().unwrap();
832         let mut slave = Endpoint::<MasterReq>::from_stream(sock);
833 
834         let mut hdr1 =
835             VhostUserMsgHeader::new(MasterReq::GET_FEATURES, 0, mem::size_of::<u64>() as u32);
836         hdr1.set_need_reply(true);
837         let features1 = 0x1u64;
838         master.send_message(&hdr1, &features1, None).unwrap();
839 
840         let mut features2 = 0u64;
841         let slice = unsafe {
842             slice::from_raw_parts_mut(
843                 (&mut features2 as *mut u64) as *mut u8,
844                 mem::size_of::<u64>(),
845             )
846         };
847         let (hdr2, bytes, rfds) = slave.recv_body_into_buf(slice).unwrap();
848         assert_eq!(hdr1, hdr2);
849         assert_eq!(bytes, 8);
850         assert_eq!(features1, features2);
851         assert!(rfds.is_none());
852 
853         master.send_header(&hdr1, None).unwrap();
854         let (hdr2, rfds) = slave.recv_header().unwrap();
855         assert_eq!(hdr1, hdr2);
856         assert!(rfds.is_none());
857     }
858 }
859