• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 // Copyright (C) 2019-2021 Alibaba Cloud. All rights reserved.
2 // SPDX-License-Identifier: Apache-2.0
3 
4 cfg_if::cfg_if! {
5     if #[cfg(unix)] {
6         mod unix;
7     } else if #[cfg(windows)] {
8         mod windows;
9     }
10 }
11 
12 use std::fs::File;
13 use std::mem;
14 use std::sync::Arc;
15 use std::sync::Mutex;
16 
17 use base::AsRawDescriptor;
18 use base::SafeDescriptor;
19 
20 use crate::connection::EndpointExt;
21 use crate::message::*;
22 use crate::Error;
23 use crate::HandlerResult;
24 use crate::Result;
25 use crate::SlaveReqEndpoint;
26 use crate::SystemStream;
27 
28 /// Define services provided by masters for the slave communication channel.
29 ///
30 /// The vhost-user specification defines a slave communication channel, by which slaves could
31 /// request services from masters. The [VhostUserMasterReqHandler] trait defines services provided
32 /// by masters, and it's used both on the master side and slave side.
33 /// - on the slave side, a stub forwarder implementing [VhostUserMasterReqHandler] will proxy
34 ///   service requests to masters. The [Slave] is an example stub forwarder.
35 /// - on the master side, the [MasterReqHandler] will forward service requests to a handler
36 ///   implementing [VhostUserMasterReqHandler].
37 ///
38 /// The [VhostUserMasterReqHandler] trait is design with interior mutability to improve performance
39 /// for multi-threading.
40 ///
41 /// [VhostUserMasterReqHandler]: trait.VhostUserMasterReqHandler.html
42 /// [MasterReqHandler]: struct.MasterReqHandler.html
43 /// [Slave]: struct.Slave.html
44 pub trait VhostUserMasterReqHandler {
45     /// Handle device configuration change notifications.
handle_config_change(&self) -> HandlerResult<u64>46     fn handle_config_change(&self) -> HandlerResult<u64> {
47         Err(std::io::Error::from_raw_os_error(libc::ENOSYS))
48     }
49 
50     /// Handle shared memory region mapping requests.
shmem_map( &self, _req: &VhostUserShmemMapMsg, _fd: &dyn AsRawDescriptor, ) -> HandlerResult<u64>51     fn shmem_map(
52         &self,
53         _req: &VhostUserShmemMapMsg,
54         _fd: &dyn AsRawDescriptor,
55     ) -> HandlerResult<u64> {
56         Err(std::io::Error::from_raw_os_error(libc::ENOSYS))
57     }
58 
59     /// Handle shared memory region unmapping requests.
shmem_unmap(&self, _req: &VhostUserShmemUnmapMsg) -> HandlerResult<u64>60     fn shmem_unmap(&self, _req: &VhostUserShmemUnmapMsg) -> HandlerResult<u64> {
61         Err(std::io::Error::from_raw_os_error(libc::ENOSYS))
62     }
63 
64     /// Handle virtio-fs map file requests.
fs_slave_map( &self, _fs: &VhostUserFSSlaveMsg, _fd: &dyn AsRawDescriptor, ) -> HandlerResult<u64>65     fn fs_slave_map(
66         &self,
67         _fs: &VhostUserFSSlaveMsg,
68         _fd: &dyn AsRawDescriptor,
69     ) -> HandlerResult<u64> {
70         Err(std::io::Error::from_raw_os_error(libc::ENOSYS))
71     }
72 
73     /// Handle virtio-fs unmap file requests.
fs_slave_unmap(&self, _fs: &VhostUserFSSlaveMsg) -> HandlerResult<u64>74     fn fs_slave_unmap(&self, _fs: &VhostUserFSSlaveMsg) -> HandlerResult<u64> {
75         Err(std::io::Error::from_raw_os_error(libc::ENOSYS))
76     }
77 
78     /// Handle virtio-fs sync file requests.
fs_slave_sync(&self, _fs: &VhostUserFSSlaveMsg) -> HandlerResult<u64>79     fn fs_slave_sync(&self, _fs: &VhostUserFSSlaveMsg) -> HandlerResult<u64> {
80         Err(std::io::Error::from_raw_os_error(libc::ENOSYS))
81     }
82 
83     /// Handle virtio-fs file IO requests.
fs_slave_io( &self, _fs: &VhostUserFSSlaveMsg, _fd: &dyn AsRawDescriptor, ) -> HandlerResult<u64>84     fn fs_slave_io(
85         &self,
86         _fs: &VhostUserFSSlaveMsg,
87         _fd: &dyn AsRawDescriptor,
88     ) -> HandlerResult<u64> {
89         Err(std::io::Error::from_raw_os_error(libc::ENOSYS))
90     }
91 
92     // fn handle_iotlb_msg(&mut self, iotlb: VhostUserIotlb);
93     // fn handle_vring_host_notifier(&mut self, area: VhostUserVringArea, fd: &dyn AsRawDescriptor);
94 
95     /// Handle GPU shared memory region mapping requests.
gpu_map( &self, _req: &VhostUserGpuMapMsg, _descriptor: &dyn AsRawDescriptor, ) -> HandlerResult<u64>96     fn gpu_map(
97         &self,
98         _req: &VhostUserGpuMapMsg,
99         _descriptor: &dyn AsRawDescriptor,
100     ) -> HandlerResult<u64> {
101         Err(std::io::Error::from_raw_os_error(libc::ENOSYS))
102     }
103 }
104 
105 /// A helper trait mirroring [VhostUserMasterReqHandler] but without interior mutability.
106 ///
107 /// [VhostUserMasterReqHandler]: trait.VhostUserMasterReqHandler.html
108 pub trait VhostUserMasterReqHandlerMut {
109     /// Handle device configuration change notifications.
handle_config_change(&mut self) -> HandlerResult<u64>110     fn handle_config_change(&mut self) -> HandlerResult<u64> {
111         Err(std::io::Error::from_raw_os_error(libc::ENOSYS))
112     }
113 
114     /// Handle shared memory region mapping requests.
shmem_map( &mut self, _req: &VhostUserShmemMapMsg, _fd: &dyn AsRawDescriptor, ) -> HandlerResult<u64>115     fn shmem_map(
116         &mut self,
117         _req: &VhostUserShmemMapMsg,
118         _fd: &dyn AsRawDescriptor,
119     ) -> HandlerResult<u64> {
120         Err(std::io::Error::from_raw_os_error(libc::ENOSYS))
121     }
122 
123     /// Handle shared memory region unmapping requests.
shmem_unmap(&mut self, _req: &VhostUserShmemUnmapMsg) -> HandlerResult<u64>124     fn shmem_unmap(&mut self, _req: &VhostUserShmemUnmapMsg) -> HandlerResult<u64> {
125         Err(std::io::Error::from_raw_os_error(libc::ENOSYS))
126     }
127 
128     /// Handle virtio-fs map file requests.
fs_slave_map( &mut self, _fs: &VhostUserFSSlaveMsg, _fd: &dyn AsRawDescriptor, ) -> HandlerResult<u64>129     fn fs_slave_map(
130         &mut self,
131         _fs: &VhostUserFSSlaveMsg,
132         _fd: &dyn AsRawDescriptor,
133     ) -> HandlerResult<u64> {
134         Err(std::io::Error::from_raw_os_error(libc::ENOSYS))
135     }
136 
137     /// Handle virtio-fs unmap file requests.
fs_slave_unmap(&mut self, _fs: &VhostUserFSSlaveMsg) -> HandlerResult<u64>138     fn fs_slave_unmap(&mut self, _fs: &VhostUserFSSlaveMsg) -> HandlerResult<u64> {
139         Err(std::io::Error::from_raw_os_error(libc::ENOSYS))
140     }
141 
142     /// Handle virtio-fs sync file requests.
fs_slave_sync(&mut self, _fs: &VhostUserFSSlaveMsg) -> HandlerResult<u64>143     fn fs_slave_sync(&mut self, _fs: &VhostUserFSSlaveMsg) -> HandlerResult<u64> {
144         Err(std::io::Error::from_raw_os_error(libc::ENOSYS))
145     }
146 
147     /// Handle virtio-fs file IO requests.
fs_slave_io( &mut self, _fs: &VhostUserFSSlaveMsg, _fd: &dyn AsRawDescriptor, ) -> HandlerResult<u64>148     fn fs_slave_io(
149         &mut self,
150         _fs: &VhostUserFSSlaveMsg,
151         _fd: &dyn AsRawDescriptor,
152     ) -> HandlerResult<u64> {
153         Err(std::io::Error::from_raw_os_error(libc::ENOSYS))
154     }
155 
156     // fn handle_iotlb_msg(&mut self, iotlb: VhostUserIotlb);
157     // fn handle_vring_host_notifier(&mut self, area: VhostUserVringArea, fd: RawDescriptor);
158 
159     /// Handle GPU shared memory region mapping requests.
gpu_map( &mut self, _req: &VhostUserGpuMapMsg, _descriptor: &dyn AsRawDescriptor, ) -> HandlerResult<u64>160     fn gpu_map(
161         &mut self,
162         _req: &VhostUserGpuMapMsg,
163         _descriptor: &dyn AsRawDescriptor,
164     ) -> HandlerResult<u64> {
165         Err(std::io::Error::from_raw_os_error(libc::ENOSYS))
166     }
167 }
168 
169 impl<S: VhostUserMasterReqHandlerMut> VhostUserMasterReqHandler for Mutex<S> {
handle_config_change(&self) -> HandlerResult<u64>170     fn handle_config_change(&self) -> HandlerResult<u64> {
171         self.lock().unwrap().handle_config_change()
172     }
173 
shmem_map( &self, req: &VhostUserShmemMapMsg, fd: &dyn AsRawDescriptor, ) -> HandlerResult<u64>174     fn shmem_map(
175         &self,
176         req: &VhostUserShmemMapMsg,
177         fd: &dyn AsRawDescriptor,
178     ) -> HandlerResult<u64> {
179         self.lock().unwrap().shmem_map(req, fd)
180     }
181 
shmem_unmap(&self, req: &VhostUserShmemUnmapMsg) -> HandlerResult<u64>182     fn shmem_unmap(&self, req: &VhostUserShmemUnmapMsg) -> HandlerResult<u64> {
183         self.lock().unwrap().shmem_unmap(req)
184     }
185 
fs_slave_map( &self, fs: &VhostUserFSSlaveMsg, fd: &dyn AsRawDescriptor, ) -> HandlerResult<u64>186     fn fs_slave_map(
187         &self,
188         fs: &VhostUserFSSlaveMsg,
189         fd: &dyn AsRawDescriptor,
190     ) -> HandlerResult<u64> {
191         self.lock().unwrap().fs_slave_map(fs, fd)
192     }
193 
fs_slave_unmap(&self, fs: &VhostUserFSSlaveMsg) -> HandlerResult<u64>194     fn fs_slave_unmap(&self, fs: &VhostUserFSSlaveMsg) -> HandlerResult<u64> {
195         self.lock().unwrap().fs_slave_unmap(fs)
196     }
197 
fs_slave_sync(&self, fs: &VhostUserFSSlaveMsg) -> HandlerResult<u64>198     fn fs_slave_sync(&self, fs: &VhostUserFSSlaveMsg) -> HandlerResult<u64> {
199         self.lock().unwrap().fs_slave_sync(fs)
200     }
201 
fs_slave_io( &self, fs: &VhostUserFSSlaveMsg, fd: &dyn AsRawDescriptor, ) -> HandlerResult<u64>202     fn fs_slave_io(
203         &self,
204         fs: &VhostUserFSSlaveMsg,
205         fd: &dyn AsRawDescriptor,
206     ) -> HandlerResult<u64> {
207         self.lock().unwrap().fs_slave_io(fs, fd)
208     }
209 
gpu_map( &self, req: &VhostUserGpuMapMsg, descriptor: &dyn AsRawDescriptor, ) -> HandlerResult<u64>210     fn gpu_map(
211         &self,
212         req: &VhostUserGpuMapMsg,
213         descriptor: &dyn AsRawDescriptor,
214     ) -> HandlerResult<u64> {
215         self.lock().unwrap().gpu_map(req, descriptor)
216     }
217 }
218 
219 /// The [MasterReqHandler] acts as a server on the master side, to handle service requests from
220 /// slaves on the slave communication channel. It's actually a proxy invoking the registered
221 /// handler implementing [VhostUserMasterReqHandler] to do the real work.
222 ///
223 /// [MasterReqHandler]: struct.MasterReqHandler.html
224 /// [VhostUserMasterReqHandler]: trait.VhostUserMasterReqHandler.html
225 ///
226 /// Server to handle service requests from slaves from the slave communication channel.
227 pub struct MasterReqHandler<S: VhostUserMasterReqHandler> {
228     // underlying Unix domain socket for communication
229     sub_sock: SlaveReqEndpoint,
230     tx_sock: Option<SystemStream>,
231     // Serializes tx_sock for passing to the backend.
232     serialize_tx: Box<dyn Fn(SystemStream) -> SafeDescriptor + Send>,
233     // Protocol feature VHOST_USER_PROTOCOL_F_REPLY_ACK has been negotiated.
234     reply_ack_negotiated: bool,
235 
236     /// the VirtIO backend device object
237     backend: Arc<S>,
238 }
239 
240 impl<S: VhostUserMasterReqHandler> MasterReqHandler<S> {
241     /// Create a server to handle service requests from slaves on the slave communication channel.
242     ///
243     /// This opens a pair of connected anonymous sockets to form the slave communication channel.
244     /// The socket fd returned by [Self::take_tx_descriptor()] should be sent to the slave by
245     /// [VhostUserMaster::set_slave_request_fd()].
246     ///
247     /// [Self::take_tx_descriptor()]: struct.MasterReqHandler.html#method.take_tx_descriptor
248     /// [VhostUserMaster::set_slave_request_fd()]: trait.VhostUserMaster.html#tymethod.set_slave_request_fd
new( backend: Arc<S>, serialize_tx: Box<dyn Fn(SystemStream) -> SafeDescriptor + Send>, ) -> Result<Self>249     pub fn new(
250         backend: Arc<S>,
251         serialize_tx: Box<dyn Fn(SystemStream) -> SafeDescriptor + Send>,
252     ) -> Result<Self> {
253         let (tx, rx) = SystemStream::pair()?;
254 
255         Ok(MasterReqHandler {
256             sub_sock: SlaveReqEndpoint::from(rx),
257             tx_sock: Some(tx),
258             serialize_tx,
259             reply_ack_negotiated: false,
260             backend,
261         })
262     }
263 
264     /// Get the descriptor for the slave to communication with the master.
265     ///
266     /// The caller owns the descriptor. The returned descriptor should be sent to the slave by
267     /// [VhostUserMaster::set_slave_request_fd()].
268     ///
269     /// [VhostUserMaster::set_slave_request_fd()]: trait.VhostUserMaster.html#tymethod.set_slave_request_fd
take_tx_descriptor(&mut self) -> SafeDescriptor270     pub fn take_tx_descriptor(&mut self) -> SafeDescriptor {
271         (self.serialize_tx)(self.tx_sock.take().expect("tx_sock should have a value"))
272     }
273 
274     /// Set the negotiation state of the `VHOST_USER_PROTOCOL_F_REPLY_ACK` protocol feature.
275     ///
276     /// When the `VHOST_USER_PROTOCOL_F_REPLY_ACK` protocol feature has been negotiated,
277     /// the "REPLY_ACK" flag will be set in the message header for every slave to master request
278     /// message.
set_reply_ack_flag(&mut self, enable: bool)279     pub fn set_reply_ack_flag(&mut self, enable: bool) {
280         self.reply_ack_negotiated = enable;
281     }
282 
283     /// Get the underlying backend device
backend(&self) -> Arc<S>284     pub fn backend(&self) -> Arc<S> {
285         Arc::clone(&self.backend)
286     }
287 
288     /// Main entrance to server slave request from the slave communication channel.
289     ///
290     /// The caller needs to:
291     /// - serialize calls to this function
292     /// - decide what to do when errer happens
293     /// - optional recover from failure
handle_request(&mut self) -> Result<u64>294     pub fn handle_request(&mut self) -> Result<u64> {
295         // The underlying communication channel is a Unix domain socket in
296         // stream mode, and recvmsg() is a little tricky here. To successfully
297         // receive attached file descriptors, we need to receive messages and
298         // corresponding attached file descriptors in this way:
299         // . recv messsage header and optional attached file
300         // . validate message header
301         // . recv optional message body and payload according size field in
302         //   message header
303         // . validate message body and optional payload
304         let (hdr, files) = self.sub_sock.recv_header()?;
305         self.check_attached_files(&hdr, &files)?;
306         let buf = match hdr.get_size() {
307             0 => vec![0u8; 0],
308             len => {
309                 if len as usize > MAX_MSG_SIZE {
310                     return Err(Error::InvalidMessage);
311                 }
312                 let rbuf = self.sub_sock.recv_data(len as usize)?;
313                 if rbuf.len() != len as usize {
314                     return Err(Error::InvalidMessage);
315                 }
316                 rbuf
317             }
318         };
319         let size = buf.len();
320 
321         let res = match hdr.get_code() {
322             SlaveReq::CONFIG_CHANGE_MSG => {
323                 self.check_msg_size(&hdr, size, 0)?;
324                 self.backend
325                     .handle_config_change()
326                     .map_err(Error::ReqHandlerError)
327             }
328             SlaveReq::SHMEM_MAP => {
329                 let msg = self.extract_msg_body::<VhostUserShmemMapMsg>(&hdr, size, &buf)?;
330                 // check_attached_files() has validated files
331                 self.backend
332                     .shmem_map(&msg, &files.unwrap()[0])
333                     .map_err(Error::ReqHandlerError)
334             }
335             SlaveReq::SHMEM_UNMAP => {
336                 let msg = self.extract_msg_body::<VhostUserShmemUnmapMsg>(&hdr, size, &buf)?;
337                 self.backend
338                     .shmem_unmap(&msg)
339                     .map_err(Error::ReqHandlerError)
340             }
341             SlaveReq::FS_MAP => {
342                 let msg = self.extract_msg_body::<VhostUserFSSlaveMsg>(&hdr, size, &buf)?;
343                 // check_attached_files() has validated files
344                 self.backend
345                     .fs_slave_map(&msg, &files.unwrap()[0])
346                     .map_err(Error::ReqHandlerError)
347             }
348             SlaveReq::FS_UNMAP => {
349                 let msg = self.extract_msg_body::<VhostUserFSSlaveMsg>(&hdr, size, &buf)?;
350                 self.backend
351                     .fs_slave_unmap(&msg)
352                     .map_err(Error::ReqHandlerError)
353             }
354             SlaveReq::FS_SYNC => {
355                 let msg = self.extract_msg_body::<VhostUserFSSlaveMsg>(&hdr, size, &buf)?;
356                 self.backend
357                     .fs_slave_sync(&msg)
358                     .map_err(Error::ReqHandlerError)
359             }
360             SlaveReq::FS_IO => {
361                 let msg = self.extract_msg_body::<VhostUserFSSlaveMsg>(&hdr, size, &buf)?;
362                 // check_attached_files() has validated files
363                 self.backend
364                     .fs_slave_io(&msg, &files.unwrap()[0])
365                     .map_err(Error::ReqHandlerError)
366             }
367             SlaveReq::GPU_MAP => {
368                 let msg = self.extract_msg_body::<VhostUserGpuMapMsg>(&hdr, size, &buf)?;
369                 // check_attached_files() has validated files
370                 self.backend
371                     .gpu_map(&msg, &files.unwrap()[0])
372                     .map_err(Error::ReqHandlerError)
373             }
374             _ => Err(Error::InvalidMessage),
375         };
376 
377         self.send_reply(&hdr, &res)?;
378 
379         res
380     }
381 
check_msg_size( &self, hdr: &VhostUserMsgHeader<SlaveReq>, size: usize, expected: usize, ) -> Result<()>382     fn check_msg_size(
383         &self,
384         hdr: &VhostUserMsgHeader<SlaveReq>,
385         size: usize,
386         expected: usize,
387     ) -> Result<()> {
388         if hdr.get_size() as usize != expected
389             || hdr.is_reply()
390             || hdr.get_version() != 0x1
391             || size != expected
392         {
393             return Err(Error::InvalidMessage);
394         }
395         Ok(())
396     }
397 
check_attached_files( &self, hdr: &VhostUserMsgHeader<SlaveReq>, files: &Option<Vec<File>>, ) -> Result<()>398     fn check_attached_files(
399         &self,
400         hdr: &VhostUserMsgHeader<SlaveReq>,
401         files: &Option<Vec<File>>,
402     ) -> Result<()> {
403         match hdr.get_code() {
404             SlaveReq::SHMEM_MAP | SlaveReq::FS_MAP | SlaveReq::FS_IO | SlaveReq::GPU_MAP => {
405                 // Expect a single file is passed.
406                 match files {
407                     Some(files) if files.len() == 1 => Ok(()),
408                     _ => Err(Error::InvalidMessage),
409                 }
410             }
411             _ if files.is_some() => Err(Error::InvalidMessage),
412             _ => Ok(()),
413         }
414     }
415 
extract_msg_body<T: Sized + VhostUserMsgValidator>( &self, hdr: &VhostUserMsgHeader<SlaveReq>, size: usize, buf: &[u8], ) -> Result<T>416     fn extract_msg_body<T: Sized + VhostUserMsgValidator>(
417         &self,
418         hdr: &VhostUserMsgHeader<SlaveReq>,
419         size: usize,
420         buf: &[u8],
421     ) -> Result<T> {
422         self.check_msg_size(hdr, size, mem::size_of::<T>())?;
423         let msg = unsafe { std::ptr::read_unaligned(buf.as_ptr() as *const T) };
424         if !msg.is_valid() {
425             return Err(Error::InvalidMessage);
426         }
427         Ok(msg)
428     }
429 
new_reply_header<T: Sized>( &self, req: &VhostUserMsgHeader<SlaveReq>, ) -> Result<VhostUserMsgHeader<SlaveReq>>430     fn new_reply_header<T: Sized>(
431         &self,
432         req: &VhostUserMsgHeader<SlaveReq>,
433     ) -> Result<VhostUserMsgHeader<SlaveReq>> {
434         if mem::size_of::<T>() > MAX_MSG_SIZE {
435             return Err(Error::InvalidParam);
436         }
437         Ok(VhostUserMsgHeader::new(
438             req.get_code(),
439             VhostUserHeaderFlag::REPLY.bits(),
440             mem::size_of::<T>() as u32,
441         ))
442     }
443 
send_reply(&mut self, req: &VhostUserMsgHeader<SlaveReq>, res: &Result<u64>) -> Result<()>444     fn send_reply(&mut self, req: &VhostUserMsgHeader<SlaveReq>, res: &Result<u64>) -> Result<()> {
445         if req.get_code() == SlaveReq::SHMEM_MAP
446             || req.get_code() == SlaveReq::SHMEM_UNMAP
447             || req.get_code() == SlaveReq::GPU_MAP
448             || (self.reply_ack_negotiated && req.is_need_reply())
449         {
450             let hdr = self.new_reply_header::<VhostUserU64>(req)?;
451             let def_err = libc::EINVAL;
452             let val = match res {
453                 Ok(n) => *n,
454                 Err(e) => match e {
455                     Error::ReqHandlerError(ioerr) => match ioerr.raw_os_error() {
456                         Some(rawerr) => -rawerr as u64,
457                         None => -def_err as u64,
458                     },
459                     _ => -def_err as u64,
460                 },
461             };
462             let msg = VhostUserU64::new(val);
463             self.sub_sock.send_message(&hdr, &msg, None)?;
464         }
465         Ok(())
466     }
467 }
468