• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 // Copyright (C) 2019-2021 Alibaba Cloud. All rights reserved.
2 // SPDX-License-Identifier: Apache-2.0
3 
4 use std::mem;
5 use std::os::unix::io::{AsRawFd, RawFd};
6 use std::os::unix::net::UnixStream;
7 use std::sync::{Arc, Mutex};
8 
9 use super::connection::Endpoint;
10 use super::message::*;
11 use super::{Error, HandlerResult, Result};
12 
13 /// Define services provided by masters for the slave communication channel.
14 ///
15 /// The vhost-user specification defines a slave communication channel, by which slaves could
16 /// request services from masters. The [VhostUserMasterReqHandler] trait defines services provided
17 /// by masters, and it's used both on the master side and slave side.
18 /// - on the slave side, a stub forwarder implementing [VhostUserMasterReqHandler] will proxy
19 ///   service requests to masters. The [SlaveFsCacheReq] is an example stub forwarder.
20 /// - on the master side, the [MasterReqHandler] will forward service requests to a handler
21 ///   implementing [VhostUserMasterReqHandler].
22 ///
23 /// The [VhostUserMasterReqHandler] trait is design with interior mutability to improve performance
24 /// for multi-threading.
25 ///
26 /// [VhostUserMasterReqHandler]: trait.VhostUserMasterReqHandler.html
27 /// [MasterReqHandler]: struct.MasterReqHandler.html
28 /// [SlaveFsCacheReq]: struct.SlaveFsCacheReq.html
29 pub trait VhostUserMasterReqHandler {
30     /// Handle device configuration change notifications.
handle_config_change(&self) -> HandlerResult<u64>31     fn handle_config_change(&self) -> HandlerResult<u64> {
32         Err(std::io::Error::from_raw_os_error(libc::ENOSYS))
33     }
34 
35     /// Handle virtio-fs map file requests.
fs_slave_map(&self, _fs: &VhostUserFSSlaveMsg, fd: RawFd) -> HandlerResult<u64>36     fn fs_slave_map(&self, _fs: &VhostUserFSSlaveMsg, fd: RawFd) -> HandlerResult<u64> {
37         // Safe because we have just received the rawfd from kernel.
38         unsafe { libc::close(fd) };
39         Err(std::io::Error::from_raw_os_error(libc::ENOSYS))
40     }
41 
42     /// Handle virtio-fs unmap file requests.
fs_slave_unmap(&self, _fs: &VhostUserFSSlaveMsg) -> HandlerResult<u64>43     fn fs_slave_unmap(&self, _fs: &VhostUserFSSlaveMsg) -> HandlerResult<u64> {
44         Err(std::io::Error::from_raw_os_error(libc::ENOSYS))
45     }
46 
47     /// Handle virtio-fs sync file requests.
fs_slave_sync(&self, _fs: &VhostUserFSSlaveMsg) -> HandlerResult<u64>48     fn fs_slave_sync(&self, _fs: &VhostUserFSSlaveMsg) -> HandlerResult<u64> {
49         Err(std::io::Error::from_raw_os_error(libc::ENOSYS))
50     }
51 
52     /// Handle virtio-fs file IO requests.
fs_slave_io(&self, _fs: &VhostUserFSSlaveMsg, fd: RawFd) -> HandlerResult<u64>53     fn fs_slave_io(&self, _fs: &VhostUserFSSlaveMsg, fd: RawFd) -> HandlerResult<u64> {
54         // Safe because we have just received the rawfd from kernel.
55         unsafe { libc::close(fd) };
56         Err(std::io::Error::from_raw_os_error(libc::ENOSYS))
57     }
58 
59     // fn handle_iotlb_msg(&mut self, iotlb: VhostUserIotlb);
60     // fn handle_vring_host_notifier(&mut self, area: VhostUserVringArea, fd: RawFd);
61 }
62 
63 /// A helper trait mirroring [VhostUserMasterReqHandler] but without interior mutability.
64 ///
65 /// [VhostUserMasterReqHandler]: trait.VhostUserMasterReqHandler.html
66 pub trait VhostUserMasterReqHandlerMut {
67     /// Handle device configuration change notifications.
handle_config_change(&mut self) -> HandlerResult<u64>68     fn handle_config_change(&mut self) -> HandlerResult<u64> {
69         Err(std::io::Error::from_raw_os_error(libc::ENOSYS))
70     }
71 
72     /// Handle virtio-fs map file requests.
fs_slave_map(&mut self, _fs: &VhostUserFSSlaveMsg, fd: RawFd) -> HandlerResult<u64>73     fn fs_slave_map(&mut self, _fs: &VhostUserFSSlaveMsg, fd: RawFd) -> HandlerResult<u64> {
74         // Safe because we have just received the rawfd from kernel.
75         unsafe { libc::close(fd) };
76         Err(std::io::Error::from_raw_os_error(libc::ENOSYS))
77     }
78 
79     /// Handle virtio-fs unmap file requests.
fs_slave_unmap(&mut self, _fs: &VhostUserFSSlaveMsg) -> HandlerResult<u64>80     fn fs_slave_unmap(&mut self, _fs: &VhostUserFSSlaveMsg) -> HandlerResult<u64> {
81         Err(std::io::Error::from_raw_os_error(libc::ENOSYS))
82     }
83 
84     /// Handle virtio-fs sync file requests.
fs_slave_sync(&mut self, _fs: &VhostUserFSSlaveMsg) -> HandlerResult<u64>85     fn fs_slave_sync(&mut self, _fs: &VhostUserFSSlaveMsg) -> HandlerResult<u64> {
86         Err(std::io::Error::from_raw_os_error(libc::ENOSYS))
87     }
88 
89     /// Handle virtio-fs file IO requests.
fs_slave_io(&mut self, _fs: &VhostUserFSSlaveMsg, fd: RawFd) -> HandlerResult<u64>90     fn fs_slave_io(&mut self, _fs: &VhostUserFSSlaveMsg, fd: RawFd) -> HandlerResult<u64> {
91         // Safe because we have just received the rawfd from kernel.
92         unsafe { libc::close(fd) };
93         Err(std::io::Error::from_raw_os_error(libc::ENOSYS))
94     }
95 
96     // fn handle_iotlb_msg(&mut self, iotlb: VhostUserIotlb);
97     // fn handle_vring_host_notifier(&mut self, area: VhostUserVringArea, fd: RawFd);
98 }
99 
100 impl<S: VhostUserMasterReqHandlerMut> VhostUserMasterReqHandler for Mutex<S> {
handle_config_change(&self) -> HandlerResult<u64>101     fn handle_config_change(&self) -> HandlerResult<u64> {
102         self.lock().unwrap().handle_config_change()
103     }
104 
fs_slave_map(&self, fs: &VhostUserFSSlaveMsg, fd: RawFd) -> HandlerResult<u64>105     fn fs_slave_map(&self, fs: &VhostUserFSSlaveMsg, fd: RawFd) -> HandlerResult<u64> {
106         self.lock().unwrap().fs_slave_map(fs, fd)
107     }
108 
fs_slave_unmap(&self, fs: &VhostUserFSSlaveMsg) -> HandlerResult<u64>109     fn fs_slave_unmap(&self, fs: &VhostUserFSSlaveMsg) -> HandlerResult<u64> {
110         self.lock().unwrap().fs_slave_unmap(fs)
111     }
112 
fs_slave_sync(&self, fs: &VhostUserFSSlaveMsg) -> HandlerResult<u64>113     fn fs_slave_sync(&self, fs: &VhostUserFSSlaveMsg) -> HandlerResult<u64> {
114         self.lock().unwrap().fs_slave_sync(fs)
115     }
116 
fs_slave_io(&self, fs: &VhostUserFSSlaveMsg, fd: RawFd) -> HandlerResult<u64>117     fn fs_slave_io(&self, fs: &VhostUserFSSlaveMsg, fd: RawFd) -> HandlerResult<u64> {
118         self.lock().unwrap().fs_slave_io(fs, fd)
119     }
120 }
121 
122 /// Server to handle service requests from slaves from the slave communication channel.
123 ///
124 /// The [MasterReqHandler] acts as a server on the master side, to handle service requests from
125 /// slaves on the slave communication channel. It's actually a proxy invoking the registered
126 /// handler implementing [VhostUserMasterReqHandler] to do the real work.
127 ///
128 /// [MasterReqHandler]: struct.MasterReqHandler.html
129 /// [VhostUserMasterReqHandler]: trait.VhostUserMasterReqHandler.html
130 pub struct MasterReqHandler<S: VhostUserMasterReqHandler> {
131     // underlying Unix domain socket for communication
132     sub_sock: Endpoint<SlaveReq>,
133     tx_sock: UnixStream,
134     // Protocol feature VHOST_USER_PROTOCOL_F_REPLY_ACK has been negotiated.
135     reply_ack_negotiated: bool,
136     // the VirtIO backend device object
137     backend: Arc<S>,
138     // whether the endpoint has encountered any failure
139     error: Option<i32>,
140 }
141 
142 impl<S: VhostUserMasterReqHandler> MasterReqHandler<S> {
143     /// Create a server to handle service requests from slaves on the slave communication channel.
144     ///
145     /// This opens a pair of connected anonymous sockets to form the slave communication channel.
146     /// The socket fd returned by [Self::get_tx_raw_fd()] should be sent to the slave by
147     /// [VhostUserMaster::set_slave_request_fd()].
148     ///
149     /// [Self::get_tx_raw_fd()]: struct.MasterReqHandler.html#method.get_tx_raw_fd
150     /// [VhostUserMaster::set_slave_request_fd()]: trait.VhostUserMaster.html#tymethod.set_slave_request_fd
new(backend: Arc<S>) -> Result<Self>151     pub fn new(backend: Arc<S>) -> Result<Self> {
152         let (tx, rx) = UnixStream::pair().map_err(Error::SocketError)?;
153 
154         Ok(MasterReqHandler {
155             sub_sock: Endpoint::<SlaveReq>::from_stream(rx),
156             tx_sock: tx,
157             reply_ack_negotiated: false,
158             backend,
159             error: None,
160         })
161     }
162 
163     /// Get the socket fd for the slave to communication with the master.
164     ///
165     /// The returned fd should be sent to the slave by [VhostUserMaster::set_slave_request_fd()].
166     ///
167     /// [VhostUserMaster::set_slave_request_fd()]: trait.VhostUserMaster.html#tymethod.set_slave_request_fd
get_tx_raw_fd(&self) -> RawFd168     pub fn get_tx_raw_fd(&self) -> RawFd {
169         self.tx_sock.as_raw_fd()
170     }
171 
172     /// Set the negotiation state of the `VHOST_USER_PROTOCOL_F_REPLY_ACK` protocol feature.
173     ///
174     /// When the `VHOST_USER_PROTOCOL_F_REPLY_ACK` protocol feature has been negotiated,
175     /// the "REPLY_ACK" flag will be set in the message header for every slave to master request
176     /// message.
set_reply_ack_flag(&mut self, enable: bool)177     pub fn set_reply_ack_flag(&mut self, enable: bool) {
178         self.reply_ack_negotiated = enable;
179     }
180 
181     /// Mark endpoint as failed or in normal state.
set_failed(&mut self, error: i32)182     pub fn set_failed(&mut self, error: i32) {
183         if error == 0 {
184             self.error = None;
185         } else {
186             self.error = Some(error);
187         }
188     }
189 
190     /// Main entrance to server slave request from the slave communication channel.
191     ///
192     /// The caller needs to:
193     /// - serialize calls to this function
194     /// - decide what to do when errer happens
195     /// - optional recover from failure
handle_request(&mut self) -> Result<u64>196     pub fn handle_request(&mut self) -> Result<u64> {
197         // Return error if the endpoint is already in failed state.
198         self.check_state()?;
199 
200         // The underlying communication channel is a Unix domain socket in
201         // stream mode, and recvmsg() is a little tricky here. To successfully
202         // receive attached file descriptors, we need to receive messages and
203         // corresponding attached file descriptors in this way:
204         // . recv messsage header and optional attached file
205         // . validate message header
206         // . recv optional message body and payload according size field in
207         //   message header
208         // . validate message body and optional payload
209         let (hdr, rfds) = self.sub_sock.recv_header()?;
210         let rfds = self.check_attached_rfds(&hdr, rfds)?;
211         let (size, buf) = match hdr.get_size() {
212             0 => (0, vec![0u8; 0]),
213             len => {
214                 if len as usize > MAX_MSG_SIZE {
215                     return Err(Error::InvalidMessage);
216                 }
217                 let (size2, rbuf) = self.sub_sock.recv_data(len as usize)?;
218                 if size2 != len as usize {
219                     return Err(Error::InvalidMessage);
220                 }
221                 (size2, rbuf)
222             }
223         };
224 
225         let res = match hdr.get_code() {
226             SlaveReq::CONFIG_CHANGE_MSG => {
227                 self.check_msg_size(&hdr, size, 0)?;
228                 self.backend
229                     .handle_config_change()
230                     .map_err(Error::ReqHandlerError)
231             }
232             SlaveReq::FS_MAP => {
233                 let msg = self.extract_msg_body::<VhostUserFSSlaveMsg>(&hdr, size, &buf)?;
234                 // check_attached_rfds() has validated rfds
235                 self.backend
236                     .fs_slave_map(&msg, rfds.unwrap()[0])
237                     .map_err(Error::ReqHandlerError)
238             }
239             SlaveReq::FS_UNMAP => {
240                 let msg = self.extract_msg_body::<VhostUserFSSlaveMsg>(&hdr, size, &buf)?;
241                 self.backend
242                     .fs_slave_unmap(&msg)
243                     .map_err(Error::ReqHandlerError)
244             }
245             SlaveReq::FS_SYNC => {
246                 let msg = self.extract_msg_body::<VhostUserFSSlaveMsg>(&hdr, size, &buf)?;
247                 self.backend
248                     .fs_slave_sync(&msg)
249                     .map_err(Error::ReqHandlerError)
250             }
251             SlaveReq::FS_IO => {
252                 let msg = self.extract_msg_body::<VhostUserFSSlaveMsg>(&hdr, size, &buf)?;
253                 // check_attached_rfds() has validated rfds
254                 self.backend
255                     .fs_slave_io(&msg, rfds.unwrap()[0])
256                     .map_err(Error::ReqHandlerError)
257             }
258             _ => Err(Error::InvalidMessage),
259         };
260 
261         self.send_ack_message(&hdr, &res)?;
262 
263         res
264     }
265 
check_state(&self) -> Result<()>266     fn check_state(&self) -> Result<()> {
267         match self.error {
268             Some(e) => Err(Error::SocketBroken(std::io::Error::from_raw_os_error(e))),
269             None => Ok(()),
270         }
271     }
272 
check_msg_size( &self, hdr: &VhostUserMsgHeader<SlaveReq>, size: usize, expected: usize, ) -> Result<()>273     fn check_msg_size(
274         &self,
275         hdr: &VhostUserMsgHeader<SlaveReq>,
276         size: usize,
277         expected: usize,
278     ) -> Result<()> {
279         if hdr.get_size() as usize != expected
280             || hdr.is_reply()
281             || hdr.get_version() != 0x1
282             || size != expected
283         {
284             return Err(Error::InvalidMessage);
285         }
286         Ok(())
287     }
288 
check_attached_rfds( &self, hdr: &VhostUserMsgHeader<SlaveReq>, rfds: Option<Vec<RawFd>>, ) -> Result<Option<Vec<RawFd>>>289     fn check_attached_rfds(
290         &self,
291         hdr: &VhostUserMsgHeader<SlaveReq>,
292         rfds: Option<Vec<RawFd>>,
293     ) -> Result<Option<Vec<RawFd>>> {
294         match hdr.get_code() {
295             SlaveReq::FS_MAP | SlaveReq::FS_IO => {
296                 // Expect an fd set with a single fd.
297                 match rfds {
298                     None => Err(Error::InvalidMessage),
299                     Some(fds) => {
300                         if fds.len() != 1 {
301                             Endpoint::<SlaveReq>::close_rfds(Some(fds));
302                             Err(Error::InvalidMessage)
303                         } else {
304                             Ok(Some(fds))
305                         }
306                     }
307                 }
308             }
309             _ => {
310                 if rfds.is_some() {
311                     Endpoint::<SlaveReq>::close_rfds(rfds);
312                     Err(Error::InvalidMessage)
313                 } else {
314                     Ok(rfds)
315                 }
316             }
317         }
318     }
319 
extract_msg_body<T: Sized + VhostUserMsgValidator>( &self, hdr: &VhostUserMsgHeader<SlaveReq>, size: usize, buf: &[u8], ) -> Result<T>320     fn extract_msg_body<T: Sized + VhostUserMsgValidator>(
321         &self,
322         hdr: &VhostUserMsgHeader<SlaveReq>,
323         size: usize,
324         buf: &[u8],
325     ) -> Result<T> {
326         self.check_msg_size(hdr, size, mem::size_of::<T>())?;
327         let msg = unsafe { std::ptr::read_unaligned(buf.as_ptr() as *const T) };
328         if !msg.is_valid() {
329             return Err(Error::InvalidMessage);
330         }
331         Ok(msg)
332     }
333 
new_reply_header<T: Sized>( &self, req: &VhostUserMsgHeader<SlaveReq>, ) -> Result<VhostUserMsgHeader<SlaveReq>>334     fn new_reply_header<T: Sized>(
335         &self,
336         req: &VhostUserMsgHeader<SlaveReq>,
337     ) -> Result<VhostUserMsgHeader<SlaveReq>> {
338         if mem::size_of::<T>() > MAX_MSG_SIZE {
339             return Err(Error::InvalidParam);
340         }
341         self.check_state()?;
342         Ok(VhostUserMsgHeader::new(
343             req.get_code(),
344             VhostUserHeaderFlag::REPLY.bits(),
345             mem::size_of::<T>() as u32,
346         ))
347     }
348 
send_ack_message( &mut self, req: &VhostUserMsgHeader<SlaveReq>, res: &Result<u64>, ) -> Result<()>349     fn send_ack_message(
350         &mut self,
351         req: &VhostUserMsgHeader<SlaveReq>,
352         res: &Result<u64>,
353     ) -> Result<()> {
354         if self.reply_ack_negotiated && req.is_need_reply() {
355             let hdr = self.new_reply_header::<VhostUserU64>(req)?;
356             let def_err = libc::EINVAL;
357             let val = match res {
358                 Ok(n) => *n,
359                 Err(e) => match &*e {
360                     Error::ReqHandlerError(ioerr) => match ioerr.raw_os_error() {
361                         Some(rawerr) => -rawerr as u64,
362                         None => -def_err as u64,
363                     },
364                     _ => -def_err as u64,
365                 },
366             };
367             let msg = VhostUserU64::new(val);
368             self.sub_sock.send_message(&hdr, &msg, None)?;
369         }
370         Ok(())
371     }
372 }
373 
374 impl<S: VhostUserMasterReqHandler> AsRawFd for MasterReqHandler<S> {
as_raw_fd(&self) -> RawFd375     fn as_raw_fd(&self) -> RawFd {
376         self.sub_sock.as_raw_fd()
377     }
378 }
379 
#[cfg(test)]
mod tests {
    use super::*;

    #[cfg(feature = "vhost-user-slave")]
    use crate::vhost_user::SlaveFsCacheReq;
    #[cfg(feature = "vhost-user-slave")]
    use std::os::unix::io::FromRawFd;

    /// Mock backend: accepts map requests and rejects unmap requests with `ENOSYS`.
    struct MockMasterReqHandler {}

    impl VhostUserMasterReqHandlerMut for MockMasterReqHandler {
        /// Handle virtio-fs map file requests from the slave.
        fn fs_slave_map(&mut self, _fs: &VhostUserFSSlaveMsg, fd: RawFd) -> HandlerResult<u64> {
            // SAFETY: the received fd is owned by the handler; close it exactly once.
            unsafe { libc::close(fd) };
            Ok(0)
        }

        /// Handle virtio-fs unmap file requests from the slave.
        fn fs_slave_unmap(&mut self, _fs: &VhostUserFSSlaveMsg) -> HandlerResult<u64> {
            Err(std::io::Error::from_raw_os_error(libc::ENOSYS))
        }
    }

    #[test]
    fn test_new_master_req_handler() {
        let backend = Arc::new(Mutex::new(MockMasterReqHandler {}));
        let mut handler = MasterReqHandler::new(backend).unwrap();

        assert!(handler.get_tx_raw_fd() >= 0);
        assert!(handler.as_raw_fd() >= 0);
        handler.check_state().unwrap();

        assert_eq!(handler.error, None);
        handler.set_failed(libc::EAGAIN);
        assert_eq!(handler.error, Some(libc::EAGAIN));
        handler.check_state().unwrap_err();
    }

    #[cfg(feature = "vhost-user-slave")]
    #[test]
    fn test_master_slave_req_handler() {
        let backend = Arc::new(Mutex::new(MockMasterReqHandler {}));
        let mut handler = MasterReqHandler::new(backend).unwrap();

        let fd = unsafe { libc::dup(handler.get_tx_raw_fd()) };
        if fd < 0 {
            panic!("failed to duplicated tx fd!");
        }
        let stream = unsafe { UnixStream::from_raw_fd(fd) };
        let fs_cache = SlaveFsCacheReq::from_stream(stream);

        let master = std::thread::spawn(move || {
            let res = handler.handle_request().unwrap();
            assert_eq!(res, 0);
            handler.handle_request().unwrap_err();
        });

        fs_cache
            .fs_slave_map(&VhostUserFSSlaveMsg::default(), fd)
            .unwrap();
        // When REPLY_ACK has not been negotiated, the master has no way to detect failure from
        // slave side.
        fs_cache
            .fs_slave_unmap(&VhostUserFSSlaveMsg::default())
            .unwrap();

        // Without joining, a failed assertion in the master thread would go unnoticed and
        // the test would pass regardless.
        master.join().unwrap();
    }

    #[cfg(feature = "vhost-user-slave")]
    #[test]
    fn test_master_slave_req_handler_with_ack() {
        let backend = Arc::new(Mutex::new(MockMasterReqHandler {}));
        let mut handler = MasterReqHandler::new(backend).unwrap();
        handler.set_reply_ack_flag(true);

        let fd = unsafe { libc::dup(handler.get_tx_raw_fd()) };
        if fd < 0 {
            panic!("failed to duplicated tx fd!");
        }
        let stream = unsafe { UnixStream::from_raw_fd(fd) };
        let fs_cache = SlaveFsCacheReq::from_stream(stream);

        let master = std::thread::spawn(move || {
            let res = handler.handle_request().unwrap();
            assert_eq!(res, 0);
            handler.handle_request().unwrap_err();
        });

        fs_cache.set_reply_ack_flag(true);
        fs_cache
            .fs_slave_map(&VhostUserFSSlaveMsg::default(), fd)
            .unwrap();
        fs_cache
            .fs_slave_unmap(&VhostUserFSSlaveMsg::default())
            .unwrap_err();

        // Propagate any panic from the master thread into this test.
        master.join().unwrap();
    }
}
478