1 // Copyright (C) 2019 Alibaba Cloud Computing. All rights reserved.
2 // SPDX-License-Identifier: Apache-2.0
3
4 //! Structs for Unix Domain Socket listener and endpoint.
5
6 #![allow(dead_code)]
7
8 use std::io::ErrorKind;
9 use std::marker::PhantomData;
10 use std::os::unix::io::{AsRawFd, RawFd};
11 use std::os::unix::net::{UnixListener, UnixStream};
12 use std::path::{Path, PathBuf};
13 use std::{mem, slice};
14
15 use libc::{c_void, iovec};
16 use sys_util::ScmSocket;
17
18 use super::message::*;
19 use super::{Error, Result};
20
/// Unix domain socket listener for accepting incoming connections.
///
/// The listener remembers the filesystem path it was bound to; the socket file at that
/// path is removed again when the listener is dropped (see the `Drop` impl).
#[derive(Debug)]
pub struct Listener {
    /// The bound listening socket.
    fd: UnixListener,
    /// Filesystem path of the socket file created by `bind()`.
    path: PathBuf,
}
26
27 impl Listener {
28 /// Create a unix domain socket listener.
29 ///
30 /// # Return:
31 /// * - the new Listener object on success.
32 /// * - SocketError: failed to create listener socket.
new<P: AsRef<Path>>(path: P, unlink: bool) -> Result<Self>33 pub fn new<P: AsRef<Path>>(path: P, unlink: bool) -> Result<Self> {
34 if unlink {
35 let _ = std::fs::remove_file(&path);
36 }
37 let fd = UnixListener::bind(&path).map_err(Error::SocketError)?;
38 Ok(Listener {
39 fd,
40 path: path.as_ref().to_owned(),
41 })
42 }
43
44 /// Accept an incoming connection.
45 ///
46 /// # Return:
47 /// * - Some(UnixStream): new UnixStream object if new incoming connection is available.
48 /// * - None: no incoming connection available.
49 /// * - SocketError: errors from accept().
accept(&self) -> Result<Option<UnixStream>>50 pub fn accept(&self) -> Result<Option<UnixStream>> {
51 loop {
52 match self.fd.accept() {
53 Ok((socket, _addr)) => return Ok(Some(socket)),
54 Err(e) => {
55 match e.kind() {
56 // No incoming connection available.
57 ErrorKind::WouldBlock => return Ok(None),
58 // New connection closed by peer.
59 ErrorKind::ConnectionAborted => return Ok(None),
60 // Interrupted by signals, retry
61 ErrorKind::Interrupted => continue,
62 _ => return Err(Error::SocketError(e)),
63 }
64 }
65 }
66 }
67 }
68
69 /// Change blocking status on the listener.
70 ///
71 /// # Return:
72 /// * - () on success.
73 /// * - SocketError: failure from set_nonblocking().
set_nonblocking(&self, block: bool) -> Result<()>74 pub fn set_nonblocking(&self, block: bool) -> Result<()> {
75 self.fd.set_nonblocking(block).map_err(Error::SocketError)
76 }
77 }
78
79 impl AsRawFd for Listener {
as_raw_fd(&self) -> RawFd80 fn as_raw_fd(&self) -> RawFd {
81 self.fd.as_raw_fd()
82 }
83 }
84
85 impl Drop for Listener {
drop(&mut self)86 fn drop(&mut self) {
87 let _ = std::fs::remove_file(&self.path);
88 }
89 }
90
/// Unix domain socket endpoint for vhost-user connection.
///
/// `R` is the request type carried in the `VhostUserMsgHeader<R>` headers that are sent and
/// received through this endpoint.
pub(super) struct Endpoint<R: Req> {
    /// The connected stream socket.
    sock: UnixStream,
    /// Marker tying the endpoint to its request type `R`; no `R` value is stored.
    _r: PhantomData<R>,
}
96
97 impl<R: Req> Endpoint<R> {
98 /// Create a new stream by connecting to server at `str`.
99 ///
100 /// # Return:
101 /// * - the new Endpoint object on success.
102 /// * - SocketConnect: failed to connect to peer.
connect<P: AsRef<Path>>(path: P) -> Result<Self>103 pub fn connect<P: AsRef<Path>>(path: P) -> Result<Self> {
104 let sock = UnixStream::connect(path).map_err(Error::SocketConnect)?;
105 Ok(Self::from_stream(sock))
106 }
107
108 /// Create an endpoint from a stream object.
from_stream(sock: UnixStream) -> Self109 pub fn from_stream(sock: UnixStream) -> Self {
110 Endpoint {
111 sock,
112 _r: PhantomData,
113 }
114 }
115
116 /// Sends bytes from scatter-gather vectors over the socket with optional attached file
117 /// descriptors.
118 ///
119 /// # Return:
120 /// * - number of bytes sent on success
121 /// * - SocketRetry: temporary error caused by signals or short of resources.
122 /// * - SocketBroken: the underline socket is broken.
123 /// * - SocketError: other socket related errors.
send_iovec(&mut self, iovs: &[&[u8]], fds: Option<&[RawFd]>) -> Result<usize>124 pub fn send_iovec(&mut self, iovs: &[&[u8]], fds: Option<&[RawFd]>) -> Result<usize> {
125 let rfds = match fds {
126 Some(rfds) => rfds,
127 _ => &[],
128 };
129 self.sock.send_bufs_with_fds(iovs, rfds).map_err(Into::into)
130 }
131
132 /// Sends all bytes from scatter-gather vectors over the socket with optional attached file
133 /// descriptors. Will loop until all data has been transfered.
134 ///
135 /// # Return:
136 /// * - number of bytes sent on success
137 /// * - SocketBroken: the underline socket is broken.
138 /// * - SocketError: other socket related errors.
send_iovec_all(&mut self, iovs: &[&[u8]], fds: Option<&[RawFd]>) -> Result<usize>139 pub fn send_iovec_all(&mut self, iovs: &[&[u8]], fds: Option<&[RawFd]>) -> Result<usize> {
140 let mut data_sent = 0;
141 let mut data_total = 0;
142 let iov_lens: Vec<usize> = iovs.iter().map(|iov| iov.len()).collect();
143 for len in &iov_lens {
144 data_total += len;
145 }
146
147 while (data_total - data_sent) > 0 {
148 let (nr_skip, offset) = get_sub_iovs_offset(&iov_lens, data_sent);
149 let iov = &iovs[nr_skip][offset..];
150
151 let data = &[&[iov], &iovs[(nr_skip + 1)..]].concat();
152 let sfds = if data_sent == 0 { fds } else { None };
153
154 let sent = self.send_iovec(data, sfds);
155 match sent {
156 Ok(0) => return Ok(data_sent),
157 Ok(n) => data_sent += n,
158 Err(e) => match e {
159 Error::SocketRetry(_) => {}
160 _ => return Err(e),
161 },
162 }
163 }
164 Ok(data_sent)
165 }
166
167 /// Sends bytes from a slice over the socket with optional attached file descriptors.
168 ///
169 /// # Return:
170 /// * - number of bytes sent on success
171 /// * - SocketRetry: temporary error caused by signals or short of resources.
172 /// * - SocketBroken: the underline socket is broken.
173 /// * - SocketError: other socket related errors.
send_slice(&mut self, data: &[u8], fds: Option<&[RawFd]>) -> Result<usize>174 pub fn send_slice(&mut self, data: &[u8], fds: Option<&[RawFd]>) -> Result<usize> {
175 self.send_iovec(&[data], fds)
176 }
177
178 /// Sends a header-only message with optional attached file descriptors.
179 ///
180 /// # Return:
181 /// * - number of bytes sent on success
182 /// * - SocketRetry: temporary error caused by signals or short of resources.
183 /// * - SocketBroken: the underline socket is broken.
184 /// * - SocketError: other socket related errors.
185 /// * - PartialMessage: received a partial message.
send_header( &mut self, hdr: &VhostUserMsgHeader<R>, fds: Option<&[RawFd]>, ) -> Result<()>186 pub fn send_header(
187 &mut self,
188 hdr: &VhostUserMsgHeader<R>,
189 fds: Option<&[RawFd]>,
190 ) -> Result<()> {
191 // Safe because there can't be other mutable referance to hdr.
192 let iovs = unsafe {
193 [slice::from_raw_parts(
194 hdr as *const VhostUserMsgHeader<R> as *const u8,
195 mem::size_of::<VhostUserMsgHeader<R>>(),
196 )]
197 };
198 let bytes = self.send_iovec_all(&iovs[..], fds)?;
199 if bytes != mem::size_of::<VhostUserMsgHeader<R>>() {
200 return Err(Error::PartialMessage);
201 }
202 Ok(())
203 }
204
205 /// Send a message with header and body. Optional file descriptors may be attached to
206 /// the message.
207 ///
208 /// # Return:
209 /// * - number of bytes sent on success
210 /// * - SocketRetry: temporary error caused by signals or short of resources.
211 /// * - SocketBroken: the underline socket is broken.
212 /// * - SocketError: other socket related errors.
213 /// * - PartialMessage: received a partial message.
send_message<T: Sized>( &mut self, hdr: &VhostUserMsgHeader<R>, body: &T, fds: Option<&[RawFd]>, ) -> Result<()>214 pub fn send_message<T: Sized>(
215 &mut self,
216 hdr: &VhostUserMsgHeader<R>,
217 body: &T,
218 fds: Option<&[RawFd]>,
219 ) -> Result<()> {
220 if mem::size_of::<T>() > MAX_MSG_SIZE {
221 return Err(Error::OversizedMsg);
222 }
223 // Safe because there can't be other mutable referance to hdr and body.
224 let iovs = unsafe {
225 [
226 slice::from_raw_parts(
227 hdr as *const VhostUserMsgHeader<R> as *const u8,
228 mem::size_of::<VhostUserMsgHeader<R>>(),
229 ),
230 slice::from_raw_parts(body as *const T as *const u8, mem::size_of::<T>()),
231 ]
232 };
233 let bytes = self.send_iovec_all(&iovs[..], fds)?;
234 if bytes != mem::size_of::<VhostUserMsgHeader<R>>() + mem::size_of::<T>() {
235 return Err(Error::PartialMessage);
236 }
237 Ok(())
238 }
239
240 /// Send a message with header, body and payload. Optional file descriptors
241 /// may also be attached to the message.
242 ///
243 /// # Return:
244 /// * - number of bytes sent on success
245 /// * - SocketRetry: temporary error caused by signals or short of resources.
246 /// * - SocketBroken: the underline socket is broken.
247 /// * - SocketError: other socket related errors.
248 /// * - OversizedMsg: message size is too big.
249 /// * - PartialMessage: received a partial message.
250 /// * - IncorrectFds: wrong number of attached fds.
send_message_with_payload<T: Sized>( &mut self, hdr: &VhostUserMsgHeader<R>, body: &T, payload: &[u8], fds: Option<&[RawFd]>, ) -> Result<()>251 pub fn send_message_with_payload<T: Sized>(
252 &mut self,
253 hdr: &VhostUserMsgHeader<R>,
254 body: &T,
255 payload: &[u8],
256 fds: Option<&[RawFd]>,
257 ) -> Result<()> {
258 let len = payload.len();
259 if mem::size_of::<T>() > MAX_MSG_SIZE {
260 return Err(Error::OversizedMsg);
261 }
262 if len > MAX_MSG_SIZE - mem::size_of::<T>() {
263 return Err(Error::OversizedMsg);
264 }
265 if let Some(fd_arr) = fds {
266 if fd_arr.len() > MAX_ATTACHED_FD_ENTRIES {
267 return Err(Error::IncorrectFds);
268 }
269 }
270
271 // Safe because there can't be other mutable reference to hdr, body and payload.
272 let iovs = unsafe {
273 [
274 slice::from_raw_parts(
275 hdr as *const VhostUserMsgHeader<R> as *const u8,
276 mem::size_of::<VhostUserMsgHeader<R>>(),
277 ),
278 slice::from_raw_parts(body as *const T as *const u8, mem::size_of::<T>()),
279 slice::from_raw_parts(payload.as_ptr() as *const u8, len),
280 ]
281 };
282 let total = mem::size_of::<VhostUserMsgHeader<R>>() + mem::size_of::<T>() + len;
283 let len = self.send_iovec_all(&iovs, fds)?;
284 if len != total {
285 return Err(Error::PartialMessage);
286 }
287 Ok(())
288 }
289
290 /// Reads bytes from the socket into the given scatter/gather vectors.
291 ///
292 /// # Return:
293 /// * - (number of bytes received, buf) on success
294 /// * - SocketRetry: temporary error caused by signals or short of resources.
295 /// * - SocketBroken: the underline socket is broken.
296 /// * - SocketError: other socket related errors.
recv_data(&mut self, len: usize) -> Result<(usize, Vec<u8>)>297 pub fn recv_data(&mut self, len: usize) -> Result<(usize, Vec<u8>)> {
298 let mut rbuf = vec![0u8; len];
299 let (bytes, _) = self.sock.recv_with_fds(&mut rbuf[..], &mut [])?;
300 Ok((bytes, rbuf))
301 }
302
303 /// Reads bytes from the socket into the given scatter/gather vectors with optional attached
304 /// file descriptors.
305 ///
306 /// The underlying communication channel is a Unix domain socket in STREAM mode. It's a little
307 /// tricky to pass file descriptors through such a communication channel. Let's assume that a
308 /// sender sending a message with some file descriptors attached. To successfully receive those
309 /// attached file descriptors, the receiver must obey following rules:
310 /// 1) file descriptors are attached to a message.
311 /// 2) message(packet) boundaries must be respected on the receive side.
312 /// In other words, recvmsg() operations must not cross the packet boundary, otherwise the
313 /// attached file descriptors will get lost.
314 ///
315 /// # Return:
316 /// * - (number of bytes received, [received fds]) on success
317 /// * - SocketRetry: temporary error caused by signals or short of resources.
318 /// * - SocketBroken: the underline socket is broken.
319 /// * - SocketError: other socket related errors.
recv_into_iovec(&mut self, iovs: &mut [iovec]) -> Result<(usize, Option<Vec<RawFd>>)>320 pub fn recv_into_iovec(&mut self, iovs: &mut [iovec]) -> Result<(usize, Option<Vec<RawFd>>)> {
321 let mut fd_array = vec![0; MAX_ATTACHED_FD_ENTRIES];
322 let (bytes, fds) = self.sock.recv_iovecs_with_fds(iovs, &mut fd_array)?;
323 let rfds = match fds {
324 0 => None,
325 n => {
326 let mut fds = Vec::with_capacity(n);
327 fds.extend_from_slice(&fd_array[0..n]);
328 Some(fds)
329 }
330 };
331
332 Ok((bytes, rfds))
333 }
334
335 /// Reads all bytes from the socket into the given scatter/gather vectors with optional
336 /// attached file descriptors. Will loop until all data has been transfered.
337 ///
338 /// The underlying communication channel is a Unix domain socket in STREAM mode. It's a little
339 /// tricky to pass file descriptors through such a communication channel. Let's assume that a
340 /// sender sending a message with some file descriptors attached. To successfully receive those
341 /// attached file descriptors, the receiver must obey following rules:
342 /// 1) file descriptors are attached to a message.
343 /// 2) message(packet) boundaries must be respected on the receive side.
344 /// In other words, recvmsg() operations must not cross the packet boundary, otherwise the
345 /// attached file descriptors will get lost.
346 ///
347 /// # Return:
348 /// * - (number of bytes received, [received fds]) on success
349 /// * - SocketBroken: the underline socket is broken.
350 /// * - SocketError: other socket related errors.
recv_into_iovec_all( &mut self, iovs: &mut [iovec], ) -> Result<(usize, Option<Vec<RawFd>>)>351 pub fn recv_into_iovec_all(
352 &mut self,
353 iovs: &mut [iovec],
354 ) -> Result<(usize, Option<Vec<RawFd>>)> {
355 let mut data_read = 0;
356 let mut data_total = 0;
357 let mut rfds = None;
358 let iov_lens: Vec<usize> = iovs.iter().map(|iov| iov.iov_len).collect();
359 for len in &iov_lens {
360 data_total += len;
361 }
362
363 while (data_total - data_read) > 0 {
364 let (nr_skip, offset) = get_sub_iovs_offset(&iov_lens, data_read);
365 let iov = &mut iovs[nr_skip];
366
367 let mut data = [
368 &[iovec {
369 iov_base: (iov.iov_base as usize + offset) as *mut c_void,
370 iov_len: iov.iov_len - offset,
371 }],
372 &iovs[(nr_skip + 1)..],
373 ]
374 .concat();
375
376 let res = self.recv_into_iovec(&mut data);
377 match res {
378 Ok((0, _)) => return Ok((data_read, rfds)),
379 Ok((n, fds)) => {
380 if data_read == 0 {
381 rfds = fds;
382 }
383 data_read += n;
384 }
385 Err(e) => match e {
386 Error::SocketRetry(_) => {}
387 _ => return Err(e),
388 },
389 }
390 }
391 Ok((data_read, rfds))
392 }
393
394 /// Reads bytes from the socket into a new buffer with optional attached
395 /// file descriptors. Received file descriptors are set close-on-exec.
396 ///
397 /// # Return:
398 /// * - (number of bytes received, buf, [received fds]) on success.
399 /// * - SocketRetry: temporary error caused by signals or short of resources.
400 /// * - SocketBroken: the underline socket is broken.
401 /// * - SocketError: other socket related errors.
recv_into_buf( &mut self, buf_size: usize, ) -> Result<(usize, Vec<u8>, Option<Vec<RawFd>>)>402 pub fn recv_into_buf(
403 &mut self,
404 buf_size: usize,
405 ) -> Result<(usize, Vec<u8>, Option<Vec<RawFd>>)> {
406 let mut buf = vec![0u8; buf_size];
407 let (bytes, rfds) = {
408 let mut iovs = [iovec {
409 iov_base: buf.as_mut_ptr() as *mut c_void,
410 iov_len: buf_size,
411 }];
412 self.recv_into_iovec(&mut iovs)?
413 };
414 Ok((bytes, buf, rfds))
415 }
416
417 /// Receive a header-only message with optional attached file descriptors.
418 /// Note, only the first MAX_ATTACHED_FD_ENTRIES file descriptors will be
419 /// accepted and all other file descriptor will be discard silently.
420 ///
421 /// # Return:
422 /// * - (message header, [received fds]) on success.
423 /// * - SocketRetry: temporary error caused by signals or short of resources.
424 /// * - SocketBroken: the underline socket is broken.
425 /// * - SocketError: other socket related errors.
426 /// * - PartialMessage: received a partial message.
427 /// * - InvalidMessage: received a invalid message.
recv_header(&mut self) -> Result<(VhostUserMsgHeader<R>, Option<Vec<RawFd>>)>428 pub fn recv_header(&mut self) -> Result<(VhostUserMsgHeader<R>, Option<Vec<RawFd>>)> {
429 let mut hdr = VhostUserMsgHeader::default();
430 let mut iovs = [iovec {
431 iov_base: (&mut hdr as *mut VhostUserMsgHeader<R>) as *mut c_void,
432 iov_len: mem::size_of::<VhostUserMsgHeader<R>>(),
433 }];
434 let (bytes, rfds) = self.recv_into_iovec_all(&mut iovs[..])?;
435
436 if bytes != mem::size_of::<VhostUserMsgHeader<R>>() {
437 return Err(Error::PartialMessage);
438 } else if !hdr.is_valid() {
439 return Err(Error::InvalidMessage);
440 }
441
442 Ok((hdr, rfds))
443 }
444
445 /// Receive a message with optional attached file descriptors.
446 /// Note, only the first MAX_ATTACHED_FD_ENTRIES file descriptors will be
447 /// accepted and all other file descriptor will be discard silently.
448 ///
449 /// # Return:
450 /// * - (message header, message body, [received fds]) on success.
451 /// * - SocketRetry: temporary error caused by signals or short of resources.
452 /// * - SocketBroken: the underline socket is broken.
453 /// * - SocketError: other socket related errors.
454 /// * - PartialMessage: received a partial message.
455 /// * - InvalidMessage: received a invalid message.
recv_body<T: Sized + Default + VhostUserMsgValidator>( &mut self, ) -> Result<(VhostUserMsgHeader<R>, T, Option<Vec<RawFd>>)>456 pub fn recv_body<T: Sized + Default + VhostUserMsgValidator>(
457 &mut self,
458 ) -> Result<(VhostUserMsgHeader<R>, T, Option<Vec<RawFd>>)> {
459 let mut hdr = VhostUserMsgHeader::default();
460 let mut body: T = Default::default();
461 let mut iovs = [
462 iovec {
463 iov_base: (&mut hdr as *mut VhostUserMsgHeader<R>) as *mut c_void,
464 iov_len: mem::size_of::<VhostUserMsgHeader<R>>(),
465 },
466 iovec {
467 iov_base: (&mut body as *mut T) as *mut c_void,
468 iov_len: mem::size_of::<T>(),
469 },
470 ];
471 let (bytes, rfds) = self.recv_into_iovec_all(&mut iovs[..])?;
472
473 let total = mem::size_of::<VhostUserMsgHeader<R>>() + mem::size_of::<T>();
474 if bytes != total {
475 return Err(Error::PartialMessage);
476 } else if !hdr.is_valid() || !body.is_valid() {
477 return Err(Error::InvalidMessage);
478 }
479
480 Ok((hdr, body, rfds))
481 }
482
483 /// Receive a message with header and optional content. Callers need to
484 /// pre-allocate a big enough buffer to receive the message body and
485 /// optional payload. If there are attached file descriptor associated
486 /// with the message, the first MAX_ATTACHED_FD_ENTRIES file descriptors
487 /// will be accepted and all other file descriptor will be discard
488 /// silently.
489 ///
490 /// # Return:
491 /// * - (message header, message size, [received fds]) on success.
492 /// * - SocketRetry: temporary error caused by signals or short of resources.
493 /// * - SocketBroken: the underline socket is broken.
494 /// * - SocketError: other socket related errors.
495 /// * - PartialMessage: received a partial message.
496 /// * - InvalidMessage: received a invalid message.
recv_body_into_buf( &mut self, buf: &mut [u8], ) -> Result<(VhostUserMsgHeader<R>, usize, Option<Vec<RawFd>>)>497 pub fn recv_body_into_buf(
498 &mut self,
499 buf: &mut [u8],
500 ) -> Result<(VhostUserMsgHeader<R>, usize, Option<Vec<RawFd>>)> {
501 let mut hdr = VhostUserMsgHeader::default();
502 let mut iovs = [
503 iovec {
504 iov_base: (&mut hdr as *mut VhostUserMsgHeader<R>) as *mut c_void,
505 iov_len: mem::size_of::<VhostUserMsgHeader<R>>(),
506 },
507 iovec {
508 iov_base: buf.as_mut_ptr() as *mut c_void,
509 iov_len: buf.len(),
510 },
511 ];
512 let (bytes, rfds) = self.recv_into_iovec_all(&mut iovs[..])?;
513
514 if bytes < mem::size_of::<VhostUserMsgHeader<R>>() {
515 return Err(Error::PartialMessage);
516 } else if !hdr.is_valid() {
517 return Err(Error::InvalidMessage);
518 }
519
520 Ok((hdr, bytes - mem::size_of::<VhostUserMsgHeader<R>>(), rfds))
521 }
522
523 /// Receive a message with optional payload and attached file descriptors.
524 /// Note, only the first MAX_ATTACHED_FD_ENTRIES file descriptors will be
525 /// accepted and all other file descriptor will be discard silently.
526 ///
527 /// # Return:
528 /// * - (message header, message body, size of payload, [received fds]) on success.
529 /// * - SocketRetry: temporary error caused by signals or short of resources.
530 /// * - SocketBroken: the underline socket is broken.
531 /// * - SocketError: other socket related errors.
532 /// * - PartialMessage: received a partial message.
533 /// * - InvalidMessage: received a invalid message.
534 #[cfg_attr(feature = "cargo-clippy", allow(clippy::type_complexity))]
recv_payload_into_buf<T: Sized + Default + VhostUserMsgValidator>( &mut self, buf: &mut [u8], ) -> Result<(VhostUserMsgHeader<R>, T, usize, Option<Vec<RawFd>>)>535 pub fn recv_payload_into_buf<T: Sized + Default + VhostUserMsgValidator>(
536 &mut self,
537 buf: &mut [u8],
538 ) -> Result<(VhostUserMsgHeader<R>, T, usize, Option<Vec<RawFd>>)> {
539 let mut hdr = VhostUserMsgHeader::default();
540 let mut body: T = Default::default();
541 let mut iovs = [
542 iovec {
543 iov_base: (&mut hdr as *mut VhostUserMsgHeader<R>) as *mut c_void,
544 iov_len: mem::size_of::<VhostUserMsgHeader<R>>(),
545 },
546 iovec {
547 iov_base: (&mut body as *mut T) as *mut c_void,
548 iov_len: mem::size_of::<T>(),
549 },
550 iovec {
551 iov_base: buf.as_mut_ptr() as *mut c_void,
552 iov_len: buf.len(),
553 },
554 ];
555 let (bytes, rfds) = self.recv_into_iovec_all(&mut iovs[..])?;
556
557 let total = mem::size_of::<VhostUserMsgHeader<R>>() + mem::size_of::<T>();
558 if bytes < total {
559 return Err(Error::PartialMessage);
560 } else if !hdr.is_valid() || !body.is_valid() {
561 return Err(Error::InvalidMessage);
562 }
563
564 Ok((hdr, body, bytes - total, rfds))
565 }
566
567 /// Close all raw file descriptors.
close_rfds(rfds: Option<Vec<RawFd>>)568 pub fn close_rfds(rfds: Option<Vec<RawFd>>) {
569 if let Some(fds) = rfds {
570 for fd in fds {
571 // safe because the rawfds are valid and we don't care about the result.
572 let _ = unsafe { libc::close(fd) };
573 }
574 }
575 }
576 }
577
578 impl<T: Req> AsRawFd for Endpoint<T> {
as_raw_fd(&self) -> RawFd579 fn as_raw_fd(&self) -> RawFd {
580 self.sock.as_raw_fd()
581 }
582 }
583
// Locate the position `skip_size` bytes into a sequence of buffers whose lengths are given
// by `iov_lens`.
//
// Returns `(nr_skip, offset)`: `nr_skip` is how many leading buffers are skipped entirely
// (a buffer is skipped when the remaining size is >= its length, so zero-length buffers at
// the boundary are skipped too), and `offset` is the remaining byte offset into the next
// buffer. If `skip_size` reaches past every buffer, `nr_skip == iov_lens.len()` and
// `offset` is the overshoot.
//
// For example:
//     let iov_lens = vec![4, 4, 5];
//     let size = 6;
//     assert_eq!(get_sub_iovs_offset(&iov_lens, size), (1, 2));
fn get_sub_iovs_offset(iov_lens: &[usize], skip_size: usize) -> (usize, usize) {
    let mut remaining = skip_size;
    let skipped = iov_lens
        .iter()
        .take_while(|&&len| {
            if remaining >= len {
                remaining -= len;
                true
            } else {
                false
            }
        })
        .count();
    (skipped, remaining)
}
603
#[cfg(test)]
mod tests {
    use super::*;
    use std::fs::File;
    use std::io::{Read, Seek, SeekFrom, Write};
    use std::os::unix::io::FromRawFd;
    use tempfile::{tempfile, Builder, TempDir};

    fn temp_dir() -> TempDir {
        Builder::new().prefix("/tmp/vhost_test").tempdir().unwrap()
    }

    /// Take ownership of a received descriptor and return the whole content of the file
    /// behind it. The descriptor is closed when the temporary `File` is dropped.
    fn read_fd_to_string(fd: RawFd) -> String {
        // Safe because the descriptor was just received from the peer and nothing else owns it.
        let mut file = unsafe { File::from_raw_fd(fd) };
        let mut content = String::new();
        file.seek(SeekFrom::Start(0)).unwrap();
        file.read_to_string(&mut content).unwrap();
        content
    }

    #[test]
    fn create_listener() {
        let dir = temp_dir();
        let mut path = dir.path().to_owned();
        path.push("sock");
        let listener = Listener::new(&path, true).unwrap();

        assert!(listener.as_raw_fd() > 0);
    }

    #[test]
    fn accept_connection() {
        let dir = temp_dir();
        let mut path = dir.path().to_owned();
        path.push("sock");
        let listener = Listener::new(&path, true).unwrap();
        listener.set_nonblocking(true).unwrap();

        // accept on a fd without incoming connection
        let conn = listener.accept().unwrap();
        assert!(conn.is_none());
    }

    #[test]
    fn send_data() {
        let dir = temp_dir();
        let mut path = dir.path().to_owned();
        path.push("sock");
        let listener = Listener::new(&path, true).unwrap();
        listener.set_nonblocking(true).unwrap();
        let mut master = Endpoint::<MasterReq>::connect(&path).unwrap();
        let sock = listener.accept().unwrap().unwrap();
        let mut slave = Endpoint::<MasterReq>::from_stream(sock);

        let buf1 = vec![0x1, 0x2, 0x3, 0x4];
        let mut len = master.send_slice(&buf1[..], None).unwrap();
        assert_eq!(len, 4);
        let (bytes, buf2, _) = slave.recv_into_buf(0x1000).unwrap();
        assert_eq!(bytes, 4);
        assert_eq!(&buf1[..], &buf2[..bytes]);

        len = master.send_slice(&buf1[..], None).unwrap();
        assert_eq!(len, 4);
        let (bytes, buf2, _) = slave.recv_into_buf(0x2).unwrap();
        assert_eq!(bytes, 2);
        assert_eq!(&buf1[..2], &buf2[..]);
        let (bytes, buf2, _) = slave.recv_into_buf(0x2).unwrap();
        assert_eq!(bytes, 2);
        assert_eq!(&buf1[2..], &buf2[..]);
    }

    #[test]
    fn send_fd() {
        let dir = temp_dir();
        let mut path = dir.path().to_owned();
        path.push("sock");
        let listener = Listener::new(&path, true).unwrap();
        listener.set_nonblocking(true).unwrap();
        let mut master = Endpoint::<MasterReq>::connect(&path).unwrap();
        let sock = listener.accept().unwrap().unwrap();
        let mut slave = Endpoint::<MasterReq>::from_stream(sock);

        let mut fd = tempfile().unwrap();
        write!(fd, "test").unwrap();
        let three_fds = [fd.as_raw_fd(); 3];

        // Normal case for sending/receiving file descriptors
        let buf1 = vec![0x1, 0x2, 0x3, 0x4];
        let len = master
            .send_slice(&buf1[..], Some(&[fd.as_raw_fd()]))
            .unwrap();
        assert_eq!(len, 4);

        let (bytes, buf2, rfds) = slave.recv_into_buf(4).unwrap();
        assert_eq!(bytes, 4);
        assert_eq!(&buf1[..], &buf2[..]);
        let fds = rfds.expect("expected attached file descriptors");
        assert_eq!(fds.len(), 1);
        assert_eq!(read_fd_to_string(fds[0]), "test");

        // Following communication pattern should work:
        // Sending side: data(header, body) with fds
        // Receiving side: data(header) with fds, data(body)
        let len = master.send_slice(&buf1[..], Some(&three_fds[..])).unwrap();
        assert_eq!(len, 4);

        let (bytes, buf2, rfds) = slave.recv_into_buf(0x2).unwrap();
        assert_eq!(bytes, 2);
        assert_eq!(&buf1[..2], &buf2[..]);
        let fds = rfds.expect("expected attached file descriptors");
        assert_eq!(fds.len(), 3);
        assert_eq!(read_fd_to_string(fds[1]), "test");
        // fds[1] was closed above; release the other received duplicates as well.
        Endpoint::<MasterReq>::close_rfds(Some(vec![fds[0], fds[2]]));
        let (bytes, buf2, rfds) = slave.recv_into_buf(0x2).unwrap();
        assert_eq!(bytes, 2);
        assert_eq!(&buf1[2..], &buf2[..]);
        assert!(rfds.is_none());

        // Following communication pattern should not work:
        // Sending side: data(header, body) with fds
        // Receiving side: data(header), data(body) with fds
        let len = master.send_slice(&buf1[..], Some(&three_fds[..])).unwrap();
        assert_eq!(len, 4);

        let (bytes, buf4) = slave.recv_data(2).unwrap();
        assert_eq!(bytes, 2);
        assert_eq!(&buf1[..2], &buf4[..]);
        let (bytes, buf2, rfds) = slave.recv_into_buf(0x2).unwrap();
        assert_eq!(bytes, 2);
        assert_eq!(&buf1[2..], &buf2[..]);
        assert!(rfds.is_none());

        // Following communication pattern should work:
        // Sending side: data, data with fds
        // Receiving side: data, data with fds
        let len = master.send_slice(&buf1[..], None).unwrap();
        assert_eq!(len, 4);
        let len = master.send_slice(&buf1[..], Some(&three_fds[..])).unwrap();
        assert_eq!(len, 4);

        let (bytes, buf2, rfds) = slave.recv_into_buf(0x4).unwrap();
        assert_eq!(bytes, 4);
        assert_eq!(&buf1[..], &buf2[..]);
        assert!(rfds.is_none());

        let (bytes, buf2, rfds) = slave.recv_into_buf(0x2).unwrap();
        assert_eq!(bytes, 2);
        assert_eq!(&buf1[..2], &buf2[..]);
        let fds = rfds.expect("expected attached file descriptors");
        assert_eq!(fds.len(), 3);
        assert_eq!(read_fd_to_string(fds[1]), "test");
        // fds[1] was closed above; release the other received duplicates as well.
        Endpoint::<MasterReq>::close_rfds(Some(vec![fds[0], fds[2]]));
        let (bytes, buf2, rfds) = slave.recv_into_buf(0x2).unwrap();
        assert_eq!(bytes, 2);
        assert_eq!(&buf1[2..], &buf2[..]);
        assert!(rfds.is_none());

        // Following communication pattern should not work:
        // Sending side: data1, data2 with fds
        // Receiving side: data + partial of data2, left of data2 with fds
        let len = master.send_slice(&buf1[..], None).unwrap();
        assert_eq!(len, 4);
        let len = master.send_slice(&buf1[..], Some(&three_fds[..])).unwrap();
        assert_eq!(len, 4);

        let (bytes, _) = slave.recv_data(5).unwrap();
        assert_eq!(bytes, 5);

        let (bytes, _, rfds) = slave.recv_into_buf(0x4).unwrap();
        assert_eq!(bytes, 3);
        assert!(rfds.is_none());

        // Received descriptors are owned by the receiver and must be released by it;
        // close_rfds() accepts both Some and None.
        let len = master.send_slice(&buf1[..], Some(&three_fds[..])).unwrap();
        assert_eq!(len, 4);

        let (bytes, _, rfds) = slave.recv_into_buf(0x4).unwrap();
        assert_eq!(bytes, 4);
        assert!(rfds.is_some());

        Endpoint::<MasterReq>::close_rfds(rfds);
        Endpoint::<MasterReq>::close_rfds(None);
    }

    #[test]
    fn send_recv() {
        let dir = temp_dir();
        let mut path = dir.path().to_owned();
        path.push("sock");
        let listener = Listener::new(&path, true).unwrap();
        listener.set_nonblocking(true).unwrap();
        let mut master = Endpoint::<MasterReq>::connect(&path).unwrap();
        let sock = listener.accept().unwrap().unwrap();
        let mut slave = Endpoint::<MasterReq>::from_stream(sock);

        let mut hdr1 =
            VhostUserMsgHeader::new(MasterReq::GET_FEATURES, 0, mem::size_of::<u64>() as u32);
        hdr1.set_need_reply(true);
        let features1 = 0x1u64;
        master.send_message(&hdr1, &features1, None).unwrap();

        let mut features2 = 0u64;
        let slice = unsafe {
            slice::from_raw_parts_mut(
                (&mut features2 as *mut u64) as *mut u8,
                mem::size_of::<u64>(),
            )
        };
        let (hdr2, bytes, rfds) = slave.recv_body_into_buf(slice).unwrap();
        assert_eq!(hdr1, hdr2);
        assert_eq!(bytes, 8);
        assert_eq!(features1, features2);
        assert!(rfds.is_none());

        master.send_header(&hdr1, None).unwrap();
        let (hdr2, rfds) = slave.recv_header().unwrap();
        assert_eq!(hdr1, hdr2);
        assert!(rfds.is_none());
    }

    #[test]
    fn sub_iovs_offset() {
        assert_eq!(get_sub_iovs_offset(&[4, 4, 5], 0), (0, 0));
        assert_eq!(get_sub_iovs_offset(&[4, 4, 5], 6), (1, 2));
        assert_eq!(get_sub_iovs_offset(&[4, 4, 5], 8), (2, 0));
        assert_eq!(get_sub_iovs_offset(&[4, 4, 5], 13), (3, 0));
        assert_eq!(get_sub_iovs_offset(&[4, 0, 5], 4), (2, 0));
    }
}
859