• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 // Copyright (C) 2019 Alibaba Cloud Computing. All rights reserved.
2 // SPDX-License-Identifier: Apache-2.0
3 
4 //! Common data structures for listener and connection.
5 
6 use std::fs::File;
7 use std::io::IoSliceMut;
8 use std::mem;
9 use std::path::Path;
10 
11 use base::AsRawDescriptor;
12 use base::RawDescriptor;
13 use zerocopy::AsBytes;
14 use zerocopy::FromBytes;
15 
16 use crate::connection::Req;
17 use crate::message::FrontendReq;
18 use crate::message::*;
19 use crate::sys::PlatformConnection;
20 use crate::Error;
21 use crate::Result;
22 use crate::SystemStream;
23 
24 /// Listener for accepting connections.
25 pub trait Listener: Sized {
26     /// Accept an incoming connection.
accept(&mut self) -> Result<Option<Connection<FrontendReq>>>27     fn accept(&mut self) -> Result<Option<Connection<FrontendReq>>>;
28 
29     /// Change blocking status on the listener.
set_nonblocking(&self, block: bool) -> Result<()>30     fn set_nonblocking(&self, block: bool) -> Result<()>;
31 }
32 
// Moves the cursor of `bufs` forward by `count` bytes.
//
// Buffers that `count` covers completely are removed from the front of `bufs`; the first
// surviving buffer is then shortened by whatever is left of `count`. A `count` that reaches or
// exceeds the total length leaves `bufs` empty.
//
// This mirrors the nightly API `IoSliceMut::advance_slices`, but works on `&mut [u8]`.
fn advance_slices_mut(bufs: &mut &mut [&mut [u8]], mut count: usize) {
    // Count the leading buffers that are consumed in full.
    let mut consumed = 0;
    while consumed < bufs.len() && bufs[consumed].len() <= count {
        count -= bufs[consumed].len();
        consumed += 1;
    }

    // Move the outer slice out of `bufs` so it can be re-sliced with its original lifetime.
    let outer = std::mem::take(bufs);
    *bufs = &mut outer[consumed..];

    // The loop above stopped on a buffer strictly longer than `count`, so this cannot go out of
    // bounds.
    if let Some(first) = bufs.first_mut() {
        let head = std::mem::take(first);
        *first = &mut head[count..];
    }
}
53 
54 /// A vhost-user connection at a low abstraction level. Provides methods for sending and receiving
55 /// vhost-user message headers and bodies.
56 ///
57 /// Builds on top of `PlatformConnection`, which provides methods for sending and receiving raw
58 /// bytes and file descriptors (a thin cross-platform abstraction for unix domain sockets).
59 pub struct Connection<R: Req>(
60     pub(crate) PlatformConnection,
61     std::marker::PhantomData<R>,
62     // Mark `Connection` as `!Sync` because message sends and recvs cannot safely be done
63     // concurrently.
64     std::marker::PhantomData<std::cell::Cell<()>>,
65 );
66 
67 impl<R: Req> From<SystemStream> for Connection<R> {
from(sock: SystemStream) -> Self68     fn from(sock: SystemStream) -> Self {
69         Self(
70             PlatformConnection::from(sock),
71             std::marker::PhantomData,
72             std::marker::PhantomData,
73         )
74     }
75 }
76 
77 impl<R: Req> Connection<R> {
78     /// Create a new stream by connecting to server at `path`.
connect<P: AsRef<Path>>(path: P) -> Result<Self>79     pub fn connect<P: AsRef<Path>>(path: P) -> Result<Self> {
80         Ok(Self(
81             PlatformConnection::connect(path)?,
82             std::marker::PhantomData,
83             std::marker::PhantomData,
84         ))
85     }
86 
87     /// Sends a header-only message with optional attached file descriptors.
send_header_only_message( &self, hdr: &VhostUserMsgHeader<R>, fds: Option<&[RawDescriptor]>, ) -> Result<()>88     pub fn send_header_only_message(
89         &self,
90         hdr: &VhostUserMsgHeader<R>,
91         fds: Option<&[RawDescriptor]>,
92     ) -> Result<()> {
93         self.0.send_message(hdr.as_bytes(), &[], &[], fds)
94     }
95 
96     /// Send a message with header and body. Optional file descriptors may be attached to
97     /// the message.
send_message<T: AsBytes>( &self, hdr: &VhostUserMsgHeader<R>, body: &T, fds: Option<&[RawDescriptor]>, ) -> Result<()>98     pub fn send_message<T: AsBytes>(
99         &self,
100         hdr: &VhostUserMsgHeader<R>,
101         body: &T,
102         fds: Option<&[RawDescriptor]>,
103     ) -> Result<()> {
104         self.0
105             .send_message(hdr.as_bytes(), body.as_bytes(), &[], fds)
106     }
107 
108     /// Send a message with header and body. `payload` is appended to the end of the body. Optional
109     /// file descriptors may also be attached to the message.
send_message_with_payload<T: Sized + AsBytes>( &self, hdr: &VhostUserMsgHeader<R>, body: &T, payload: &[u8], fds: Option<&[RawDescriptor]>, ) -> Result<()>110     pub fn send_message_with_payload<T: Sized + AsBytes>(
111         &self,
112         hdr: &VhostUserMsgHeader<R>,
113         body: &T,
114         payload: &[u8],
115         fds: Option<&[RawDescriptor]>,
116     ) -> Result<()> {
117         self.0
118             .send_message(hdr.as_bytes(), body.as_bytes(), payload, fds)
119     }
120 
121     /// Reads all bytes into the given scatter/gather vectors with optional attached files. Will
122     /// loop until all data has been transfered and errors if EOF is reached before then.
123     ///
124     /// # Return:
125     /// * - received fds on success
126     /// * - `Disconnect` - client is closed
127     ///
128     /// # TODO
129     /// This function takes a slice of `&mut [u8]` instead of `IoSliceMut` because the internal
130     /// cursor needs to be moved by `advance_slices_mut()`.
131     /// Once `IoSliceMut::advance_slices()` becomes stable, this should be updated.
132     /// <https://github.com/rust-lang/rust/issues/62726>.
recv_into_bufs_all(&self, mut bufs: &mut [&mut [u8]]) -> Result<Vec<File>>133     fn recv_into_bufs_all(&self, mut bufs: &mut [&mut [u8]]) -> Result<Vec<File>> {
134         let mut first_read = true;
135         let mut rfds = Vec::new();
136 
137         // Guarantee that `bufs` becomes empty if it doesn't contain any data.
138         advance_slices_mut(&mut bufs, 0);
139 
140         while !bufs.is_empty() {
141             let mut slices: Vec<IoSliceMut> = bufs.iter_mut().map(|b| IoSliceMut::new(b)).collect();
142             let res = self.0.recv_into_bufs(&mut slices, true);
143             match res {
144                 Ok((0, _)) => return Err(Error::PartialMessage),
145                 Ok((n, fds)) => {
146                     if first_read {
147                         first_read = false;
148                         if let Some(fds) = fds {
149                             rfds = fds;
150                         }
151                     }
152                     advance_slices_mut(&mut bufs, n);
153                 }
154                 Err(e) => match e {
155                     Error::SocketRetry(_) => {}
156                     _ => return Err(e),
157                 },
158             }
159         }
160         Ok(rfds)
161     }
162 
163     /// Receive message header
164     ///
165     /// Errors if the header is invalid.
166     ///
167     /// Note, only the first MAX_ATTACHED_FD_ENTRIES file descriptors will be accepted and all
168     /// other file descriptor will be discard silently.
recv_header(&self) -> Result<(VhostUserMsgHeader<R>, Vec<File>)>169     pub fn recv_header(&self) -> Result<(VhostUserMsgHeader<R>, Vec<File>)> {
170         let mut hdr = VhostUserMsgHeader::default();
171         let files = self.recv_into_bufs_all(&mut [hdr.as_bytes_mut()])?;
172         if !hdr.is_valid() {
173             return Err(Error::InvalidMessage);
174         }
175         Ok((hdr, files))
176     }
177 
178     /// Receive the body following the header `hdr`.
recv_body_bytes(&self, hdr: &VhostUserMsgHeader<R>) -> Result<Vec<u8>>179     pub fn recv_body_bytes(&self, hdr: &VhostUserMsgHeader<R>) -> Result<Vec<u8>> {
180         // NOTE: `recv_into_bufs_all` is a noop when the buffer is empty, so `hdr.get_size() == 0`
181         // works as expected.
182         let mut body = vec![0; hdr.get_size().try_into().unwrap()];
183         let files = self.recv_into_bufs_all(&mut [&mut body[..]])?;
184         if !files.is_empty() {
185             return Err(Error::InvalidMessage);
186         }
187         Ok(body)
188     }
189 
190     /// Receive a message header and body.
191     ///
192     /// Errors if the header or body is invalid.
193     ///
194     /// Note, only the first MAX_ATTACHED_FD_ENTRIES file descriptors will be
195     /// accepted and all other file descriptor will be discard silently.
recv_message<T: AsBytes + FromBytes + VhostUserMsgValidator>( &self, ) -> Result<(VhostUserMsgHeader<R>, T, Vec<File>)>196     pub fn recv_message<T: AsBytes + FromBytes + VhostUserMsgValidator>(
197         &self,
198     ) -> Result<(VhostUserMsgHeader<R>, T, Vec<File>)> {
199         let mut hdr = VhostUserMsgHeader::default();
200         let mut body = T::new_zeroed();
201         let mut slices = [hdr.as_bytes_mut(), body.as_bytes_mut()];
202         let files = self.recv_into_bufs_all(&mut slices)?;
203 
204         if !hdr.is_valid() || !body.is_valid() {
205             return Err(Error::InvalidMessage);
206         }
207 
208         Ok((hdr, body, files))
209     }
210 
211     /// Receive a message header and body, where the body includes a variable length payload at the
212     /// end.
213     ///
214     /// Errors if the header or body is invalid.
215     ///
216     /// Note, only the first MAX_ATTACHED_FD_ENTRIES file descriptors will be accepted and all
217     /// other file descriptor will be discard silently.
recv_message_with_payload<T: AsBytes + FromBytes + VhostUserMsgValidator>( &self, ) -> Result<(VhostUserMsgHeader<R>, T, Vec<u8>, Vec<File>)>218     pub fn recv_message_with_payload<T: AsBytes + FromBytes + VhostUserMsgValidator>(
219         &self,
220     ) -> Result<(VhostUserMsgHeader<R>, T, Vec<u8>, Vec<File>)> {
221         let (hdr, files) = self.recv_header()?;
222 
223         let mut body = T::new_zeroed();
224         let payload_size = hdr.get_size() as usize - mem::size_of::<T>();
225         let mut buf: Vec<u8> = vec![0; payload_size];
226         let mut slices = [body.as_bytes_mut(), buf.as_bytes_mut()];
227         let more_files = self.recv_into_bufs_all(&mut slices)?;
228         if !body.is_valid() || !more_files.is_empty() {
229             return Err(Error::InvalidMessage);
230         }
231 
232         Ok((hdr, body, buf, files))
233     }
234 }
235 
236 impl<R: Req> AsRawDescriptor for Connection<R> {
as_raw_descriptor(&self) -> RawDescriptor237     fn as_raw_descriptor(&self) -> RawDescriptor {
238         self.0.as_raw_descriptor()
239     }
240 }
241 
#[cfg(test)]
pub(crate) mod tests {
    use std::io::Read;
    use std::io::Seek;
    use std::io::SeekFrom;
    use std::io::Write;

    use tempfile::tempfile;

    use super::*;
    use crate::message::VhostUserEmptyMessage;
    use crate::message::VhostUserU64;
    use crate::tests::create_connection_pair;

    // A header-only message is received unchanged and without any files.
    #[test]
    fn send_header_only() {
        let (sender, receiver) = create_connection_pair();
        let sent_hdr = VhostUserMsgHeader::new(FrontendReq::GET_FEATURES, 0, 0);
        sender.send_header_only_message(&sent_hdr, None).unwrap();

        let (received_hdr, _, files) = receiver.recv_message::<VhostUserEmptyMessage>().unwrap();
        assert_eq!(sent_hdr, received_hdr);
        assert!(files.is_empty());
    }

    // A header followed by a `u64` body is received unchanged.
    #[test]
    fn send_data() {
        let (sender, receiver) = create_connection_pair();
        let sent_hdr = VhostUserMsgHeader::new(FrontendReq::SET_FEATURES, 0, 8);
        let sent_body = VhostUserU64::new(0xf00dbeefdeadf00d);
        sender.send_message(&sent_hdr, &sent_body, None).unwrap();

        let (received_hdr, body, files) = receiver.recv_message::<VhostUserU64>().unwrap();
        assert_eq!(sent_hdr, received_hdr);
        // Copy the field into a local before comparing it.
        let value = body.value;
        assert_eq!(value, 0xf00dbeefdeadf00d);
        assert!(files.is_empty());
    }

    // A file attached to a message is received and refers to the same content.
    #[test]
    fn send_fd() {
        let (sender, receiver) = create_connection_pair();

        let mut fd = tempfile().unwrap();
        write!(fd, "test").unwrap();

        // Normal case for sending/receiving file descriptors
        let sent_hdr = VhostUserMsgHeader::new(FrontendReq::SET_MEM_TABLE, 0, 0);
        let attached = [fd.as_raw_descriptor()];
        sender
            .send_header_only_message(&sent_hdr, Some(&attached))
            .unwrap();

        let (received_hdr, _, files) = receiver.recv_message::<VhostUserEmptyMessage>().unwrap();
        assert_eq!(sent_hdr, received_hdr);
        assert_eq!(files.len(), 1);

        // Rewind before reading: the write above moved the file offset.
        let mut file = &files[0];
        file.seek(SeekFrom::Start(0)).unwrap();
        let mut content = String::new();
        file.read_to_string(&mut content).unwrap();
        assert_eq!(content, "test");
    }

    #[test]
    fn test_advance_slices_mut() {
        // Test case from https://doc.rust-lang.org/std/io/struct.IoSliceMut.html#method.advance_slices
        let mut first = [1; 8];
        let mut second = [2; 16];
        let mut third = [3; 8];
        let mut bufs = &mut [&mut first[..], &mut second[..], &mut third[..]][..];
        advance_slices_mut(&mut bufs, 10);
        assert_eq!(bufs[0], [2; 14].as_ref());
        assert_eq!(bufs[1], [3; 8].as_ref());
    }
}
321