• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 // Copyright (C) 2019 Alibaba Cloud Computing. All rights reserved.
2 // SPDX-License-Identifier: Apache-2.0
3 
4 //! Common data structures for listener and connection.
5 
6 use std::fs::File;
7 use std::io::IoSliceMut;
8 use std::mem;
9 
10 use base::AsRawDescriptor;
11 use base::RawDescriptor;
12 use zerocopy::FromBytes;
13 use zerocopy::Immutable;
14 use zerocopy::IntoBytes;
15 
16 use crate::connection::Req;
17 use crate::message::FrontendReq;
18 use crate::message::*;
19 use crate::sys::PlatformConnection;
20 use crate::Error;
21 use crate::Result;
22 
/// Listener for accepting connections.
///
/// Accepted connections are typed as `Connection<FrontendReq>`, i.e. they exchange messages whose
/// headers are `VhostUserMsgHeader<FrontendReq>`.
pub trait Listener: Sized {
    /// Accept an incoming connection.
    ///
    /// NOTE(review): `Ok(None)` presumably means that no connection was pending (e.g. when the
    /// listener is in non-blocking mode) — confirm against the implementors.
    fn accept(&mut self) -> Result<Option<Connection<FrontendReq>>>;

    /// Change blocking status on the listener.
    ///
    /// NOTE(review): given the method name, passing `true` presumably puts the listener into
    /// *non*-blocking mode, despite the parameter being called `block` — confirm against the
    /// implementors.
    fn set_nonblocking(&self, block: bool) -> Result<()>;
}
31 
/// Moves the read cursor of `bufs` forward by `count` bytes.
///
/// Every slice that `count` covers completely is removed from the front of `bufs`. The first slice
/// that is only partially covered has its consumed prefix cut off. Afterwards the front slice (if
/// any) is never zero-length. Calling this with `count == 0` therefore leaves `bufs` empty when it
/// holds no bytes at all.
///
/// This mirrors the nightly `IoSliceMut::advance_slices` API, but works on `&mut [u8]`. There is
/// one difference: if `count` exceeds the total remaining length, this does not panic; `bufs`
/// simply ends up empty.
fn advance_slices_mut(bufs: &mut &mut [&mut [u8]], mut count: usize) {
    // Count the leading slices that are fully consumed, subtracting their lengths from `count`.
    let mut consumed = 0;
    for slice in bufs.iter() {
        match count.checked_sub(slice.len()) {
            Some(rest) => {
                count = rest;
                consumed += 1;
            }
            None => break,
        }
    }

    // Drop the fully consumed slices from the front.
    *bufs = &mut std::mem::take(bufs)[consumed..];

    // `count` is now strictly smaller than the first remaining slice (if any); trim that prefix.
    if let Some(first) = bufs.first_mut() {
        let whole = std::mem::take(first);
        *first = &mut whole[count..];
    }
}
52 
/// A vhost-user connection at a low abstraction level. Provides methods for sending and receiving
/// vhost-user message headers and bodies.
///
/// Builds on top of `PlatformConnection`, which provides methods for sending and receiving raw
/// bytes and file descriptors (a thin cross-platform abstraction for unix domain sockets).
///
/// The type parameter `R` is the request type of the `VhostUserMsgHeader<R>` headers this
/// connection sends and receives.
pub struct Connection<R: Req>(
    // The underlying transport; every send/recv method on `Connection` delegates to it.
    pub(crate) PlatformConnection,
    // Ties the connection to its request type `R` without storing a value of it.
    pub(crate) std::marker::PhantomData<R>,
    // Mark `Connection` as `!Sync` because message sends and recvs cannot safely be done
    // concurrently.
    pub(crate) std::marker::PhantomData<std::cell::Cell<()>>,
);
65 
66 impl<R: Req> Connection<R> {
67     /// Sends a header-only message with optional attached file descriptors.
send_header_only_message( &self, hdr: &VhostUserMsgHeader<R>, fds: Option<&[RawDescriptor]>, ) -> Result<()>68     pub fn send_header_only_message(
69         &self,
70         hdr: &VhostUserMsgHeader<R>,
71         fds: Option<&[RawDescriptor]>,
72     ) -> Result<()> {
73         self.0
74             .send_message(hdr.into_raw().as_bytes(), &[], &[], fds)
75     }
76 
77     /// Send a message with header and body. Optional file descriptors may be attached to
78     /// the message.
send_message<T: IntoBytes + Immutable>( &self, hdr: &VhostUserMsgHeader<R>, body: &T, fds: Option<&[RawDescriptor]>, ) -> Result<()>79     pub fn send_message<T: IntoBytes + Immutable>(
80         &self,
81         hdr: &VhostUserMsgHeader<R>,
82         body: &T,
83         fds: Option<&[RawDescriptor]>,
84     ) -> Result<()> {
85         self.0
86             .send_message(hdr.into_raw().as_bytes(), body.as_bytes(), &[], fds)
87     }
88 
89     /// Send a message with header and body. `payload` is appended to the end of the body. Optional
90     /// file descriptors may also be attached to the message.
send_message_with_payload<T: IntoBytes + Immutable>( &self, hdr: &VhostUserMsgHeader<R>, body: &T, payload: &[u8], fds: Option<&[RawDescriptor]>, ) -> Result<()>91     pub fn send_message_with_payload<T: IntoBytes + Immutable>(
92         &self,
93         hdr: &VhostUserMsgHeader<R>,
94         body: &T,
95         payload: &[u8],
96         fds: Option<&[RawDescriptor]>,
97     ) -> Result<()> {
98         self.0
99             .send_message(hdr.into_raw().as_bytes(), body.as_bytes(), payload, fds)
100     }
101 
102     /// Reads all bytes into the given scatter/gather vectors with optional attached files. Will
103     /// loop until all data has been transfered and errors if EOF is reached before then.
104     ///
105     /// # Return:
106     /// * - received fds on success
107     /// * - `Disconnect` - client is closed
108     ///
109     /// # TODO
110     /// This function takes a slice of `&mut [u8]` instead of `IoSliceMut` because the internal
111     /// cursor needs to be moved by `advance_slices_mut()`.
112     /// Once `IoSliceMut::advance_slices()` becomes stable, this should be updated.
113     /// <https://github.com/rust-lang/rust/issues/62726>.
recv_into_bufs_all(&self, mut bufs: &mut [&mut [u8]]) -> Result<Vec<File>>114     fn recv_into_bufs_all(&self, mut bufs: &mut [&mut [u8]]) -> Result<Vec<File>> {
115         let mut first_read = true;
116         let mut rfds = Vec::new();
117 
118         // Guarantee that `bufs` becomes empty if it doesn't contain any data.
119         advance_slices_mut(&mut bufs, 0);
120 
121         while !bufs.is_empty() {
122             let mut slices: Vec<IoSliceMut> = bufs.iter_mut().map(|b| IoSliceMut::new(b)).collect();
123             let res = self.0.recv_into_bufs(&mut slices, true);
124             match res {
125                 Ok((0, _)) => return Err(Error::PartialMessage),
126                 Ok((n, fds)) => {
127                     if first_read {
128                         first_read = false;
129                         if let Some(fds) = fds {
130                             rfds = fds;
131                         }
132                     }
133                     advance_slices_mut(&mut bufs, n);
134                 }
135                 Err(e) => match e {
136                     Error::SocketRetry(_) => {}
137                     _ => return Err(e),
138                 },
139             }
140         }
141         Ok(rfds)
142     }
143 
144     /// Receive message header
145     ///
146     /// Errors if the header is invalid.
147     ///
148     /// Note, only the first MAX_ATTACHED_FD_ENTRIES file descriptors will be accepted and all
149     /// other file descriptor will be discard silently.
recv_header(&self) -> Result<(VhostUserMsgHeader<R>, Vec<File>)>150     pub fn recv_header(&self) -> Result<(VhostUserMsgHeader<R>, Vec<File>)> {
151         let mut hdr_raw = [0u32; 3];
152         let files = self.recv_into_bufs_all(&mut [hdr_raw.as_mut_bytes()])?;
153         let hdr = VhostUserMsgHeader::from_raw(hdr_raw);
154         if !hdr.is_valid() {
155             return Err(Error::InvalidMessage);
156         }
157         Ok((hdr, files))
158     }
159 
160     /// Receive the body following the header `hdr`.
recv_body_bytes(&self, hdr: &VhostUserMsgHeader<R>) -> Result<Vec<u8>>161     pub fn recv_body_bytes(&self, hdr: &VhostUserMsgHeader<R>) -> Result<Vec<u8>> {
162         // NOTE: `recv_into_bufs_all` is a noop when the buffer is empty, so `hdr.get_size() == 0`
163         // works as expected.
164         let mut body = vec![0; hdr.get_size().try_into().unwrap()];
165         let files = self.recv_into_bufs_all(&mut [&mut body[..]])?;
166         if !files.is_empty() {
167             return Err(Error::InvalidMessage);
168         }
169         Ok(body)
170     }
171 
172     /// Receive a message header and body.
173     ///
174     /// Errors if the header or body is invalid.
175     ///
176     /// Note, only the first MAX_ATTACHED_FD_ENTRIES file descriptors will be
177     /// accepted and all other file descriptor will be discard silently.
recv_message<T: IntoBytes + FromBytes + VhostUserMsgValidator>( &self, ) -> Result<(VhostUserMsgHeader<R>, T, Vec<File>)>178     pub fn recv_message<T: IntoBytes + FromBytes + VhostUserMsgValidator>(
179         &self,
180     ) -> Result<(VhostUserMsgHeader<R>, T, Vec<File>)> {
181         let mut hdr_raw = [0u32; 3];
182         let mut body = T::new_zeroed();
183         let mut slices = [hdr_raw.as_mut_bytes(), body.as_mut_bytes()];
184         let files = self.recv_into_bufs_all(&mut slices)?;
185 
186         let hdr = VhostUserMsgHeader::from_raw(hdr_raw);
187         if !hdr.is_valid() || !body.is_valid() {
188             return Err(Error::InvalidMessage);
189         }
190 
191         Ok((hdr, body, files))
192     }
193 
194     /// Receive a message header and body, where the body includes a variable length payload at the
195     /// end.
196     ///
197     /// Errors if the header or body is invalid.
198     ///
199     /// Note, only the first MAX_ATTACHED_FD_ENTRIES file descriptors will be accepted and all
200     /// other file descriptor will be discard silently.
recv_message_with_payload<T: IntoBytes + FromBytes + VhostUserMsgValidator>( &self, ) -> Result<(VhostUserMsgHeader<R>, T, Vec<u8>, Vec<File>)>201     pub fn recv_message_with_payload<T: IntoBytes + FromBytes + VhostUserMsgValidator>(
202         &self,
203     ) -> Result<(VhostUserMsgHeader<R>, T, Vec<u8>, Vec<File>)> {
204         let (hdr, files) = self.recv_header()?;
205 
206         let mut body = T::new_zeroed();
207         let payload_size = hdr.get_size() as usize - mem::size_of::<T>();
208         let mut buf: Vec<u8> = vec![0; payload_size];
209         let mut slices = [body.as_mut_bytes(), buf.as_mut_bytes()];
210         let more_files = self.recv_into_bufs_all(&mut slices)?;
211         if !body.is_valid() || !more_files.is_empty() {
212             return Err(Error::InvalidMessage);
213         }
214 
215         Ok((hdr, body, buf, files))
216     }
217 }
218 
219 impl<R: Req> AsRawDescriptor for Connection<R> {
as_raw_descriptor(&self) -> RawDescriptor220     fn as_raw_descriptor(&self) -> RawDescriptor {
221         self.0.as_raw_descriptor()
222     }
223 }
224 
#[cfg(test)]
pub(crate) mod tests {
    use std::io::Read;
    use std::io::Seek;
    use std::io::SeekFrom;
    use std::io::Write;

    use tempfile::tempfile;

    use super::*;
    use crate::message::VhostUserEmptyMessage;
    use crate::message::VhostUserU64;

    #[test]
    fn send_header_only() {
        let (client_connection, server_connection) = Connection::pair().unwrap();
        let hdr1 = VhostUserMsgHeader::new(FrontendReq::GET_FEATURES, 0, 0);
        client_connection
            .send_header_only_message(&hdr1, None)
            .unwrap();
        let (hdr2, _, files) = server_connection
            .recv_message::<VhostUserEmptyMessage>()
            .unwrap();
        assert_eq!(hdr1, hdr2);
        assert!(files.is_empty());
    }

    #[test]
    fn send_data() {
        let (client_connection, server_connection) = Connection::pair().unwrap();
        let hdr1 = VhostUserMsgHeader::new(FrontendReq::SET_FEATURES, 0, 8);
        client_connection
            .send_message(&hdr1, &VhostUserU64::new(0xf00dbeefdeadf00d), None)
            .unwrap();
        let (hdr2, body, files) = server_connection.recv_message::<VhostUserU64>().unwrap();
        assert_eq!(hdr1, hdr2);
        let value = body.value;
        assert_eq!(value, 0xf00dbeefdeadf00d);
        assert!(files.is_empty());
    }

    #[test]
    fn send_fd() {
        let (client_connection, server_connection) = Connection::pair().unwrap();

        let mut fd = tempfile().unwrap();
        write!(fd, "test").unwrap();

        // Normal case for sending/receiving file descriptors
        let hdr1 = VhostUserMsgHeader::new(FrontendReq::SET_MEM_TABLE, 0, 0);
        client_connection
            .send_header_only_message(&hdr1, Some(&[fd.as_raw_descriptor()]))
            .unwrap();

        let (hdr2, _, files) = server_connection
            .recv_message::<VhostUserEmptyMessage>()
            .unwrap();
        assert_eq!(hdr1, hdr2);
        assert_eq!(files.len(), 1);
        let mut file = &files[0];
        let mut content = String::new();
        file.seek(SeekFrom::Start(0)).unwrap();
        file.read_to_string(&mut content).unwrap();
        assert_eq!(content, "test");
    }

    #[test]
    fn test_advance_slices_mut() {
        // Test case from https://doc.rust-lang.org/std/io/struct.IoSliceMut.html#method.advance_slices
        let mut buf1 = [1; 8];
        let mut buf2 = [2; 16];
        let mut buf3 = [3; 8];
        let mut bufs = &mut [&mut buf1[..], &mut buf2[..], &mut buf3[..]][..];
        advance_slices_mut(&mut bufs, 10);
        assert_eq!(bufs[0], [2; 14].as_ref());
        assert_eq!(bufs[1], [3; 8].as_ref());
    }

    /// `recv_into_bufs_all` relies on `advance_slices_mut(.., 0)` dropping leading empty slices,
    /// and on exact consumption leaving the list empty.
    #[test]
    fn test_advance_slices_mut_empty_slices() {
        let mut empty1: [u8; 0] = [];
        let mut data = [4; 4];
        let mut empty2: [u8; 0] = [];
        let mut bufs = &mut [&mut empty1[..], &mut data[..], &mut empty2[..]][..];
        advance_slices_mut(&mut bufs, 0);
        assert_eq!(bufs.len(), 2);
        assert_eq!(bufs[0], [4; 4].as_ref());
        advance_slices_mut(&mut bufs, 4);
        assert!(bufs.is_empty());
    }

    /// Advancing past the end empties the list instead of panicking.
    #[test]
    fn test_advance_slices_mut_past_end() {
        let mut buf1 = [1; 2];
        let mut buf2 = [2; 3];
        let mut bufs = &mut [&mut buf1[..], &mut buf2[..]][..];
        advance_slices_mut(&mut bufs, 100);
        assert!(bufs.is_empty());
    }
}
303