1 // Copyright (C) 2019 Alibaba Cloud Computing. All rights reserved.
2 // SPDX-License-Identifier: Apache-2.0
3
4 //! Common data structures for listener and connection.
5
6 use std::fs::File;
7 use std::io::IoSliceMut;
8 use std::mem;
9
10 use base::AsRawDescriptor;
11 use base::RawDescriptor;
12 use zerocopy::FromBytes;
13 use zerocopy::Immutable;
14 use zerocopy::IntoBytes;
15
16 use crate::connection::Req;
17 use crate::message::FrontendReq;
18 use crate::message::*;
19 use crate::sys::PlatformConnection;
20 use crate::Error;
21 use crate::Result;
22
23 /// Listener for accepting connections.
24 pub trait Listener: Sized {
25 /// Accept an incoming connection.
accept(&mut self) -> Result<Option<Connection<FrontendReq>>>26 fn accept(&mut self) -> Result<Option<Connection<FrontendReq>>>;
27
28 /// Change blocking status on the listener.
set_nonblocking(&self, block: bool) -> Result<()>29 fn set_nonblocking(&self, block: bool) -> Result<()>;
30 }
31
/// Moves the logical start of `bufs` forward by `count` bytes.
///
/// Slices fully covered by `count` are removed from the front of `bufs`. The first remaining
/// slice is then trimmed so that it begins at the first unconsumed byte. A leading zero-length
/// slice is always removed, even when `count` is 0.
///
/// This does the same job as `IoSliceMut::advance_slices`, but works on plain `&mut [u8]`
/// slices. Unlike that API, advancing past the total length does not panic. The excess is
/// ignored and `bufs` ends up empty.
fn advance_slices_mut(bufs: &mut &mut [&mut [u8]], mut count: usize) {
    use std::mem::take;

    // How many leading slices are consumed in their entirety.
    let mut consumed = 0;
    while let Some(head) = bufs.get(consumed) {
        if head.len() > count {
            break;
        }
        count -= head.len();
        consumed += 1;
    }

    let whole = take(bufs);
    *bufs = &mut whole[consumed..];

    // Trim the partially consumed slice (if any) so it starts at the next unread byte.
    if let Some(head) = bufs.first_mut() {
        let full = take(head);
        *head = &mut full[count..];
    }
}
52
/// A vhost-user connection at a low abstraction level. Provides methods for sending and receiving
/// vhost-user message headers and bodies.
///
/// Builds on top of `PlatformConnection`, which provides methods for sending and receiving raw
/// bytes and file descriptors (a thin cross-platform abstraction for unix domain sockets).
///
/// `R` is the request type used in message headers. It is tracked only at the type level through
/// `PhantomData` and is never stored.
pub struct Connection<R: Req>(
    /// The underlying transport used for every send and receive.
    pub(crate) PlatformConnection,
    /// Ties this connection to its request type `R` without storing a value of it.
    pub(crate) std::marker::PhantomData<R>,
    // Mark `Connection` as `!Sync` because message sends and recvs cannot safely be done
    // concurrently. `Cell<()>` is `Send` but not `Sync`, so this marker does not affect `Send`.
    pub(crate) std::marker::PhantomData<std::cell::Cell<()>>,
);
65
66 impl<R: Req> Connection<R> {
67 /// Sends a header-only message with optional attached file descriptors.
send_header_only_message( &self, hdr: &VhostUserMsgHeader<R>, fds: Option<&[RawDescriptor]>, ) -> Result<()>68 pub fn send_header_only_message(
69 &self,
70 hdr: &VhostUserMsgHeader<R>,
71 fds: Option<&[RawDescriptor]>,
72 ) -> Result<()> {
73 self.0
74 .send_message(hdr.into_raw().as_bytes(), &[], &[], fds)
75 }
76
77 /// Send a message with header and body. Optional file descriptors may be attached to
78 /// the message.
send_message<T: IntoBytes + Immutable>( &self, hdr: &VhostUserMsgHeader<R>, body: &T, fds: Option<&[RawDescriptor]>, ) -> Result<()>79 pub fn send_message<T: IntoBytes + Immutable>(
80 &self,
81 hdr: &VhostUserMsgHeader<R>,
82 body: &T,
83 fds: Option<&[RawDescriptor]>,
84 ) -> Result<()> {
85 self.0
86 .send_message(hdr.into_raw().as_bytes(), body.as_bytes(), &[], fds)
87 }
88
89 /// Send a message with header and body. `payload` is appended to the end of the body. Optional
90 /// file descriptors may also be attached to the message.
send_message_with_payload<T: IntoBytes + Immutable>( &self, hdr: &VhostUserMsgHeader<R>, body: &T, payload: &[u8], fds: Option<&[RawDescriptor]>, ) -> Result<()>91 pub fn send_message_with_payload<T: IntoBytes + Immutable>(
92 &self,
93 hdr: &VhostUserMsgHeader<R>,
94 body: &T,
95 payload: &[u8],
96 fds: Option<&[RawDescriptor]>,
97 ) -> Result<()> {
98 self.0
99 .send_message(hdr.into_raw().as_bytes(), body.as_bytes(), payload, fds)
100 }
101
102 /// Reads all bytes into the given scatter/gather vectors with optional attached files. Will
103 /// loop until all data has been transfered and errors if EOF is reached before then.
104 ///
105 /// # Return:
106 /// * - received fds on success
107 /// * - `Disconnect` - client is closed
108 ///
109 /// # TODO
110 /// This function takes a slice of `&mut [u8]` instead of `IoSliceMut` because the internal
111 /// cursor needs to be moved by `advance_slices_mut()`.
112 /// Once `IoSliceMut::advance_slices()` becomes stable, this should be updated.
113 /// <https://github.com/rust-lang/rust/issues/62726>.
recv_into_bufs_all(&self, mut bufs: &mut [&mut [u8]]) -> Result<Vec<File>>114 fn recv_into_bufs_all(&self, mut bufs: &mut [&mut [u8]]) -> Result<Vec<File>> {
115 let mut first_read = true;
116 let mut rfds = Vec::new();
117
118 // Guarantee that `bufs` becomes empty if it doesn't contain any data.
119 advance_slices_mut(&mut bufs, 0);
120
121 while !bufs.is_empty() {
122 let mut slices: Vec<IoSliceMut> = bufs.iter_mut().map(|b| IoSliceMut::new(b)).collect();
123 let res = self.0.recv_into_bufs(&mut slices, true);
124 match res {
125 Ok((0, _)) => return Err(Error::PartialMessage),
126 Ok((n, fds)) => {
127 if first_read {
128 first_read = false;
129 if let Some(fds) = fds {
130 rfds = fds;
131 }
132 }
133 advance_slices_mut(&mut bufs, n);
134 }
135 Err(e) => match e {
136 Error::SocketRetry(_) => {}
137 _ => return Err(e),
138 },
139 }
140 }
141 Ok(rfds)
142 }
143
144 /// Receive message header
145 ///
146 /// Errors if the header is invalid.
147 ///
148 /// Note, only the first MAX_ATTACHED_FD_ENTRIES file descriptors will be accepted and all
149 /// other file descriptor will be discard silently.
recv_header(&self) -> Result<(VhostUserMsgHeader<R>, Vec<File>)>150 pub fn recv_header(&self) -> Result<(VhostUserMsgHeader<R>, Vec<File>)> {
151 let mut hdr_raw = [0u32; 3];
152 let files = self.recv_into_bufs_all(&mut [hdr_raw.as_mut_bytes()])?;
153 let hdr = VhostUserMsgHeader::from_raw(hdr_raw);
154 if !hdr.is_valid() {
155 return Err(Error::InvalidMessage);
156 }
157 Ok((hdr, files))
158 }
159
160 /// Receive the body following the header `hdr`.
recv_body_bytes(&self, hdr: &VhostUserMsgHeader<R>) -> Result<Vec<u8>>161 pub fn recv_body_bytes(&self, hdr: &VhostUserMsgHeader<R>) -> Result<Vec<u8>> {
162 // NOTE: `recv_into_bufs_all` is a noop when the buffer is empty, so `hdr.get_size() == 0`
163 // works as expected.
164 let mut body = vec![0; hdr.get_size().try_into().unwrap()];
165 let files = self.recv_into_bufs_all(&mut [&mut body[..]])?;
166 if !files.is_empty() {
167 return Err(Error::InvalidMessage);
168 }
169 Ok(body)
170 }
171
172 /// Receive a message header and body.
173 ///
174 /// Errors if the header or body is invalid.
175 ///
176 /// Note, only the first MAX_ATTACHED_FD_ENTRIES file descriptors will be
177 /// accepted and all other file descriptor will be discard silently.
recv_message<T: IntoBytes + FromBytes + VhostUserMsgValidator>( &self, ) -> Result<(VhostUserMsgHeader<R>, T, Vec<File>)>178 pub fn recv_message<T: IntoBytes + FromBytes + VhostUserMsgValidator>(
179 &self,
180 ) -> Result<(VhostUserMsgHeader<R>, T, Vec<File>)> {
181 let mut hdr_raw = [0u32; 3];
182 let mut body = T::new_zeroed();
183 let mut slices = [hdr_raw.as_mut_bytes(), body.as_mut_bytes()];
184 let files = self.recv_into_bufs_all(&mut slices)?;
185
186 let hdr = VhostUserMsgHeader::from_raw(hdr_raw);
187 if !hdr.is_valid() || !body.is_valid() {
188 return Err(Error::InvalidMessage);
189 }
190
191 Ok((hdr, body, files))
192 }
193
194 /// Receive a message header and body, where the body includes a variable length payload at the
195 /// end.
196 ///
197 /// Errors if the header or body is invalid.
198 ///
199 /// Note, only the first MAX_ATTACHED_FD_ENTRIES file descriptors will be accepted and all
200 /// other file descriptor will be discard silently.
recv_message_with_payload<T: IntoBytes + FromBytes + VhostUserMsgValidator>( &self, ) -> Result<(VhostUserMsgHeader<R>, T, Vec<u8>, Vec<File>)>201 pub fn recv_message_with_payload<T: IntoBytes + FromBytes + VhostUserMsgValidator>(
202 &self,
203 ) -> Result<(VhostUserMsgHeader<R>, T, Vec<u8>, Vec<File>)> {
204 let (hdr, files) = self.recv_header()?;
205
206 let mut body = T::new_zeroed();
207 let payload_size = hdr.get_size() as usize - mem::size_of::<T>();
208 let mut buf: Vec<u8> = vec![0; payload_size];
209 let mut slices = [body.as_mut_bytes(), buf.as_mut_bytes()];
210 let more_files = self.recv_into_bufs_all(&mut slices)?;
211 if !body.is_valid() || !more_files.is_empty() {
212 return Err(Error::InvalidMessage);
213 }
214
215 Ok((hdr, body, buf, files))
216 }
217 }
218
219 impl<R: Req> AsRawDescriptor for Connection<R> {
as_raw_descriptor(&self) -> RawDescriptor220 fn as_raw_descriptor(&self) -> RawDescriptor {
221 self.0.as_raw_descriptor()
222 }
223 }
224
#[cfg(test)]
pub(crate) mod tests {
    use std::io::Read;
    use std::io::Seek;
    use std::io::SeekFrom;
    use std::io::Write;

    use tempfile::tempfile;

    use super::*;
    use crate::message::VhostUserEmptyMessage;
    use crate::message::VhostUserU64;

    #[test]
    fn send_header_only() {
        let (client_connection, server_connection) = Connection::pair().unwrap();
        let hdr1 = VhostUserMsgHeader::new(FrontendReq::GET_FEATURES, 0, 0);
        client_connection
            .send_header_only_message(&hdr1, None)
            .unwrap();
        let (hdr2, _, files) = server_connection
            .recv_message::<VhostUserEmptyMessage>()
            .unwrap();
        assert_eq!(hdr1, hdr2);
        assert!(files.is_empty());
    }

    #[test]
    fn send_data() {
        let (client_connection, server_connection) = Connection::pair().unwrap();
        let hdr1 = VhostUserMsgHeader::new(FrontendReq::SET_FEATURES, 0, 8);
        client_connection
            .send_message(&hdr1, &VhostUserU64::new(0xf00dbeefdeadf00d), None)
            .unwrap();
        let (hdr2, body, files) = server_connection.recv_message::<VhostUserU64>().unwrap();
        assert_eq!(hdr1, hdr2);
        let value = body.value;
        assert_eq!(value, 0xf00dbeefdeadf00d);
        assert!(files.is_empty());
    }

    #[test]
    fn send_fd() {
        let (client_connection, server_connection) = Connection::pair().unwrap();

        let mut fd = tempfile().unwrap();
        write!(fd, "test").unwrap();

        // Normal case for sending/receiving file descriptors
        let hdr1 = VhostUserMsgHeader::new(FrontendReq::SET_MEM_TABLE, 0, 0);
        client_connection
            .send_header_only_message(&hdr1, Some(&[fd.as_raw_descriptor()]))
            .unwrap();

        let (hdr2, _, files) = server_connection
            .recv_message::<VhostUserEmptyMessage>()
            .unwrap();
        assert_eq!(hdr1, hdr2);
        assert_eq!(files.len(), 1);
        let mut file = &files[0];
        let mut content = String::new();
        file.seek(SeekFrom::Start(0)).unwrap();
        file.read_to_string(&mut content).unwrap();
        assert_eq!(content, "test");
    }

    #[test]
    fn test_advance_slices_mut() {
        // Test case from https://doc.rust-lang.org/std/io/struct.IoSliceMut.html#method.advance_slices
        let mut buf1 = [1; 8];
        let mut buf2 = [2; 16];
        let mut buf3 = [3; 8];
        let mut bufs = &mut [&mut buf1[..], &mut buf2[..], &mut buf3[..]][..];
        advance_slices_mut(&mut bufs, 10);
        assert_eq!(bufs[0], [2; 14].as_ref());
        assert_eq!(bufs[1], [3; 8].as_ref());
    }

    #[test]
    fn test_advance_slices_mut_zero_count_skips_empty() {
        // `recv_into_bufs_all` relies on a zero-length advance dropping leading empty slices.
        let mut empty: [u8; 0] = [];
        let mut data = [4u8; 3];
        let mut bufs = &mut [&mut empty[..], &mut data[..]][..];
        advance_slices_mut(&mut bufs, 0);
        assert_eq!(bufs.len(), 1);
        assert_eq!(&*bufs[0], &[4u8; 3][..]);

        // Advancing by exactly the remaining length leaves nothing behind.
        advance_slices_mut(&mut bufs, 3);
        assert!(bufs.is_empty());
    }

    #[test]
    fn test_advance_slices_mut_exact_boundary() {
        // Landing exactly on a slice boundary drops the consumed slice and leaves the next intact.
        let mut first = [1u8; 2];
        let mut second = [2u8; 2];
        let mut bufs = &mut [&mut first[..], &mut second[..]][..];
        advance_slices_mut(&mut bufs, 2);
        assert_eq!(bufs.len(), 1);
        assert_eq!(&*bufs[0], &[2u8; 2][..]);
    }
}
303