• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 // Copyright (C) 2019-2021 Alibaba Cloud. All rights reserved.
2 // SPDX-License-Identifier: Apache-2.0
3 
4 #[cfg(unix)]
5 use std::fs::File;
6 #[cfg(unix)]
7 use std::mem;
8 #[cfg(unix)]
9 use std::os::unix::io::AsRawFd;
10 use std::sync::Mutex;
11 
12 #[cfg(unix)]
13 use super::connection::{socket::Endpoint as SocketEndpoint, EndpointExt};
14 use super::message::*;
15 use super::HandlerResult;
16 #[cfg(unix)]
17 use super::{Error, Result};
18 #[cfg(unix)]
19 use crate::SystemStream;
20 #[cfg(unix)]
21 use std::sync::Arc;
22 
23 use base::AsRawDescriptor;
24 #[cfg(unix)]
25 use base::RawDescriptor;
26 
27 /// Define services provided by masters for the slave communication channel.
28 ///
29 /// The vhost-user specification defines a slave communication channel, by which slaves could
30 /// request services from masters. The [VhostUserMasterReqHandler] trait defines services provided
31 /// by masters, and it's used both on the master side and slave side.
32 /// - on the slave side, a stub forwarder implementing [VhostUserMasterReqHandler] will proxy
33 ///   service requests to masters. The [SlaveFsCacheReq] is an example stub forwarder.
34 /// - on the master side, the [MasterReqHandler] will forward service requests to a handler
35 ///   implementing [VhostUserMasterReqHandler].
36 ///
37 /// The [VhostUserMasterReqHandler] trait is design with interior mutability to improve performance
38 /// for multi-threading.
39 ///
40 /// [VhostUserMasterReqHandler]: trait.VhostUserMasterReqHandler.html
41 /// [MasterReqHandler]: struct.MasterReqHandler.html
42 /// [SlaveFsCacheReq]: struct.SlaveFsCacheReq.html
43 pub trait VhostUserMasterReqHandler {
44     /// Handle device configuration change notifications.
handle_config_change(&self) -> HandlerResult<u64>45     fn handle_config_change(&self) -> HandlerResult<u64> {
46         Err(std::io::Error::from_raw_os_error(libc::ENOSYS))
47     }
48 
49     /// Handle virtio-fs map file requests.
fs_slave_map( &self, _fs: &VhostUserFSSlaveMsg, _fd: &dyn AsRawDescriptor, ) -> HandlerResult<u64>50     fn fs_slave_map(
51         &self,
52         _fs: &VhostUserFSSlaveMsg,
53         _fd: &dyn AsRawDescriptor,
54     ) -> HandlerResult<u64> {
55         Err(std::io::Error::from_raw_os_error(libc::ENOSYS))
56     }
57 
58     /// Handle virtio-fs unmap file requests.
fs_slave_unmap(&self, _fs: &VhostUserFSSlaveMsg) -> HandlerResult<u64>59     fn fs_slave_unmap(&self, _fs: &VhostUserFSSlaveMsg) -> HandlerResult<u64> {
60         Err(std::io::Error::from_raw_os_error(libc::ENOSYS))
61     }
62 
63     /// Handle virtio-fs sync file requests.
fs_slave_sync(&self, _fs: &VhostUserFSSlaveMsg) -> HandlerResult<u64>64     fn fs_slave_sync(&self, _fs: &VhostUserFSSlaveMsg) -> HandlerResult<u64> {
65         Err(std::io::Error::from_raw_os_error(libc::ENOSYS))
66     }
67 
68     /// Handle virtio-fs file IO requests.
fs_slave_io( &self, _fs: &VhostUserFSSlaveMsg, _fd: &dyn AsRawDescriptor, ) -> HandlerResult<u64>69     fn fs_slave_io(
70         &self,
71         _fs: &VhostUserFSSlaveMsg,
72         _fd: &dyn AsRawDescriptor,
73     ) -> HandlerResult<u64> {
74         Err(std::io::Error::from_raw_os_error(libc::ENOSYS))
75     }
76 
77     // fn handle_iotlb_msg(&mut self, iotlb: VhostUserIotlb);
78     // fn handle_vring_host_notifier(&mut self, area: VhostUserVringArea, fd: &dyn AsRawDescriptor);
79 }
80 
81 /// A helper trait mirroring [VhostUserMasterReqHandler] but without interior mutability.
82 ///
83 /// [VhostUserMasterReqHandler]: trait.VhostUserMasterReqHandler.html
84 pub trait VhostUserMasterReqHandlerMut {
85     /// Handle device configuration change notifications.
handle_config_change(&mut self) -> HandlerResult<u64>86     fn handle_config_change(&mut self) -> HandlerResult<u64> {
87         Err(std::io::Error::from_raw_os_error(libc::ENOSYS))
88     }
89 
90     /// Handle virtio-fs map file requests.
fs_slave_map( &mut self, _fs: &VhostUserFSSlaveMsg, _fd: &dyn AsRawDescriptor, ) -> HandlerResult<u64>91     fn fs_slave_map(
92         &mut self,
93         _fs: &VhostUserFSSlaveMsg,
94         _fd: &dyn AsRawDescriptor,
95     ) -> HandlerResult<u64> {
96         Err(std::io::Error::from_raw_os_error(libc::ENOSYS))
97     }
98 
99     /// Handle virtio-fs unmap file requests.
fs_slave_unmap(&mut self, _fs: &VhostUserFSSlaveMsg) -> HandlerResult<u64>100     fn fs_slave_unmap(&mut self, _fs: &VhostUserFSSlaveMsg) -> HandlerResult<u64> {
101         Err(std::io::Error::from_raw_os_error(libc::ENOSYS))
102     }
103 
104     /// Handle virtio-fs sync file requests.
fs_slave_sync(&mut self, _fs: &VhostUserFSSlaveMsg) -> HandlerResult<u64>105     fn fs_slave_sync(&mut self, _fs: &VhostUserFSSlaveMsg) -> HandlerResult<u64> {
106         Err(std::io::Error::from_raw_os_error(libc::ENOSYS))
107     }
108 
109     /// Handle virtio-fs file IO requests.
fs_slave_io( &mut self, _fs: &VhostUserFSSlaveMsg, _fd: &dyn AsRawDescriptor, ) -> HandlerResult<u64>110     fn fs_slave_io(
111         &mut self,
112         _fs: &VhostUserFSSlaveMsg,
113         _fd: &dyn AsRawDescriptor,
114     ) -> HandlerResult<u64> {
115         Err(std::io::Error::from_raw_os_error(libc::ENOSYS))
116     }
117 
118     // fn handle_iotlb_msg(&mut self, iotlb: VhostUserIotlb);
119     // fn handle_vring_host_notifier(&mut self, area: VhostUserVringArea, fd: RawDescriptor);
120 }
121 
122 impl<S: VhostUserMasterReqHandlerMut> VhostUserMasterReqHandler for Mutex<S> {
handle_config_change(&self) -> HandlerResult<u64>123     fn handle_config_change(&self) -> HandlerResult<u64> {
124         self.lock().unwrap().handle_config_change()
125     }
126 
fs_slave_map( &self, fs: &VhostUserFSSlaveMsg, fd: &dyn AsRawDescriptor, ) -> HandlerResult<u64>127     fn fs_slave_map(
128         &self,
129         fs: &VhostUserFSSlaveMsg,
130         fd: &dyn AsRawDescriptor,
131     ) -> HandlerResult<u64> {
132         self.lock().unwrap().fs_slave_map(fs, fd)
133     }
134 
fs_slave_unmap(&self, fs: &VhostUserFSSlaveMsg) -> HandlerResult<u64>135     fn fs_slave_unmap(&self, fs: &VhostUserFSSlaveMsg) -> HandlerResult<u64> {
136         self.lock().unwrap().fs_slave_unmap(fs)
137     }
138 
fs_slave_sync(&self, fs: &VhostUserFSSlaveMsg) -> HandlerResult<u64>139     fn fs_slave_sync(&self, fs: &VhostUserFSSlaveMsg) -> HandlerResult<u64> {
140         self.lock().unwrap().fs_slave_sync(fs)
141     }
142 
fs_slave_io( &self, fs: &VhostUserFSSlaveMsg, fd: &dyn AsRawDescriptor, ) -> HandlerResult<u64>143     fn fs_slave_io(
144         &self,
145         fs: &VhostUserFSSlaveMsg,
146         fd: &dyn AsRawDescriptor,
147     ) -> HandlerResult<u64> {
148         self.lock().unwrap().fs_slave_io(fs, fd)
149     }
150 }
151 
152 /// The [MasterReqHandler] acts as a server on the master side, to handle service requests from
153 /// slaves on the slave communication channel. It's actually a proxy invoking the registered
154 /// handler implementing [VhostUserMasterReqHandler] to do the real work.
155 ///
156 /// [MasterReqHandler]: struct.MasterReqHandler.html
157 /// [VhostUserMasterReqHandler]: trait.VhostUserMasterReqHandler.html
158 ///
159 /// TODO(b/221882601): we can write a version of this for Windows by switching the socket for a Tube.
160 /// The interfaces would need to change so that we fetch a full Tube (which is 2 rds on Windows)
161 /// and send it to the device backend (slave) as a message on the master -> slave channel.
162 /// (Currently the interface only supports sending a single rd.)
163 ///
164 /// Note that handling requests from slaves is not needed for the initial devices we plan to
165 /// support.
166 ///
167 /// Server to handle service requests from slaves from the slave communication channel.
168 #[cfg(unix)]
169 pub struct MasterReqHandler<S: VhostUserMasterReqHandler> {
170     // underlying Unix domain socket for communication
171     sub_sock: SocketEndpoint<SlaveReq>,
172     tx_sock: SystemStream,
173     // Protocol feature VHOST_USER_PROTOCOL_F_REPLY_ACK has been negotiated.
174     reply_ack_negotiated: bool,
175     // the VirtIO backend device object
176     backend: Arc<S>,
177     // whether the endpoint has encountered any failure
178     error: Option<i32>,
179 }
180 
181 #[cfg(unix)]
182 impl<S: VhostUserMasterReqHandler> MasterReqHandler<S> {
183     /// Create a server to handle service requests from slaves on the slave communication channel.
184     ///
185     /// This opens a pair of connected anonymous sockets to form the slave communication channel.
186     /// The socket fd returned by [Self::get_tx_raw_fd()] should be sent to the slave by
187     /// [VhostUserMaster::set_slave_request_fd()].
188     ///
189     /// [Self::get_tx_raw_fd()]: struct.MasterReqHandler.html#method.get_tx_raw_fd
190     /// [VhostUserMaster::set_slave_request_fd()]: trait.VhostUserMaster.html#tymethod.set_slave_request_fd
new(backend: Arc<S>) -> Result<Self>191     pub fn new(backend: Arc<S>) -> Result<Self> {
192         let (tx, rx) = SystemStream::pair().map_err(Error::SocketError)?;
193 
194         Ok(MasterReqHandler {
195             sub_sock: SocketEndpoint::<SlaveReq>::from(rx),
196             tx_sock: tx,
197             reply_ack_negotiated: false,
198             backend,
199             error: None,
200         })
201     }
202 
203     /// Get the socket fd for the slave to communication with the master.
204     ///
205     /// The returned fd should be sent to the slave by [VhostUserMaster::set_slave_request_fd()].
206     ///
207     /// [VhostUserMaster::set_slave_request_fd()]: trait.VhostUserMaster.html#tymethod.set_slave_request_fd
get_tx_raw_fd(&self) -> RawDescriptor208     pub fn get_tx_raw_fd(&self) -> RawDescriptor {
209         self.tx_sock.as_raw_fd()
210     }
211 
212     /// Set the negotiation state of the `VHOST_USER_PROTOCOL_F_REPLY_ACK` protocol feature.
213     ///
214     /// When the `VHOST_USER_PROTOCOL_F_REPLY_ACK` protocol feature has been negotiated,
215     /// the "REPLY_ACK" flag will be set in the message header for every slave to master request
216     /// message.
set_reply_ack_flag(&mut self, enable: bool)217     pub fn set_reply_ack_flag(&mut self, enable: bool) {
218         self.reply_ack_negotiated = enable;
219     }
220 
221     /// Mark endpoint as failed or in normal state.
set_failed(&mut self, error: i32)222     pub fn set_failed(&mut self, error: i32) {
223         if error == 0 {
224             self.error = None;
225         } else {
226             self.error = Some(error);
227         }
228     }
229 
230     /// Main entrance to server slave request from the slave communication channel.
231     ///
232     /// The caller needs to:
233     /// - serialize calls to this function
234     /// - decide what to do when errer happens
235     /// - optional recover from failure
handle_request(&mut self) -> Result<u64>236     pub fn handle_request(&mut self) -> Result<u64> {
237         // Return error if the endpoint is already in failed state.
238         self.check_state()?;
239 
240         // The underlying communication channel is a Unix domain socket in
241         // stream mode, and recvmsg() is a little tricky here. To successfully
242         // receive attached file descriptors, we need to receive messages and
243         // corresponding attached file descriptors in this way:
244         // . recv messsage header and optional attached file
245         // . validate message header
246         // . recv optional message body and payload according size field in
247         //   message header
248         // . validate message body and optional payload
249         let (hdr, files) = self.sub_sock.recv_header()?;
250         self.check_attached_files(&hdr, &files)?;
251         let buf = match hdr.get_size() {
252             0 => vec![0u8; 0],
253             len => {
254                 if len as usize > MAX_MSG_SIZE {
255                     return Err(Error::InvalidMessage);
256                 }
257                 let rbuf = self.sub_sock.recv_data(len as usize)?;
258                 if rbuf.len() != len as usize {
259                     return Err(Error::InvalidMessage);
260                 }
261                 rbuf
262             }
263         };
264         let size = buf.len();
265 
266         let res = match hdr.get_code() {
267             SlaveReq::CONFIG_CHANGE_MSG => {
268                 self.check_msg_size(&hdr, size, 0)?;
269                 self.backend
270                     .handle_config_change()
271                     .map_err(Error::ReqHandlerError)
272             }
273             SlaveReq::FS_MAP => {
274                 let msg = self.extract_msg_body::<VhostUserFSSlaveMsg>(&hdr, size, &buf)?;
275                 // check_attached_files() has validated files
276                 self.backend
277                     .fs_slave_map(&msg, &files.unwrap()[0])
278                     .map_err(Error::ReqHandlerError)
279             }
280             SlaveReq::FS_UNMAP => {
281                 let msg = self.extract_msg_body::<VhostUserFSSlaveMsg>(&hdr, size, &buf)?;
282                 self.backend
283                     .fs_slave_unmap(&msg)
284                     .map_err(Error::ReqHandlerError)
285             }
286             SlaveReq::FS_SYNC => {
287                 let msg = self.extract_msg_body::<VhostUserFSSlaveMsg>(&hdr, size, &buf)?;
288                 self.backend
289                     .fs_slave_sync(&msg)
290                     .map_err(Error::ReqHandlerError)
291             }
292             SlaveReq::FS_IO => {
293                 let msg = self.extract_msg_body::<VhostUserFSSlaveMsg>(&hdr, size, &buf)?;
294                 // check_attached_files() has validated files
295                 self.backend
296                     .fs_slave_io(&msg, &files.unwrap()[0])
297                     .map_err(Error::ReqHandlerError)
298             }
299             _ => Err(Error::InvalidMessage),
300         };
301 
302         self.send_ack_message(&hdr, &res)?;
303 
304         res
305     }
306 
check_state(&self) -> Result<()>307     fn check_state(&self) -> Result<()> {
308         match self.error {
309             Some(e) => Err(Error::SocketBroken(std::io::Error::from_raw_os_error(e))),
310             None => Ok(()),
311         }
312     }
313 
check_msg_size( &self, hdr: &VhostUserMsgHeader<SlaveReq>, size: usize, expected: usize, ) -> Result<()>314     fn check_msg_size(
315         &self,
316         hdr: &VhostUserMsgHeader<SlaveReq>,
317         size: usize,
318         expected: usize,
319     ) -> Result<()> {
320         if hdr.get_size() as usize != expected
321             || hdr.is_reply()
322             || hdr.get_version() != 0x1
323             || size != expected
324         {
325             return Err(Error::InvalidMessage);
326         }
327         Ok(())
328     }
329 
check_attached_files( &self, hdr: &VhostUserMsgHeader<SlaveReq>, files: &Option<Vec<File>>, ) -> Result<()>330     fn check_attached_files(
331         &self,
332         hdr: &VhostUserMsgHeader<SlaveReq>,
333         files: &Option<Vec<File>>,
334     ) -> Result<()> {
335         match hdr.get_code() {
336             SlaveReq::FS_MAP | SlaveReq::FS_IO => {
337                 // Expect a single file is passed.
338                 match files {
339                     Some(files) if files.len() == 1 => Ok(()),
340                     _ => Err(Error::InvalidMessage),
341                 }
342             }
343             _ if files.is_some() => Err(Error::InvalidMessage),
344             _ => Ok(()),
345         }
346     }
347 
extract_msg_body<T: Sized + VhostUserMsgValidator>( &self, hdr: &VhostUserMsgHeader<SlaveReq>, size: usize, buf: &[u8], ) -> Result<T>348     fn extract_msg_body<T: Sized + VhostUserMsgValidator>(
349         &self,
350         hdr: &VhostUserMsgHeader<SlaveReq>,
351         size: usize,
352         buf: &[u8],
353     ) -> Result<T> {
354         self.check_msg_size(hdr, size, mem::size_of::<T>())?;
355         let msg = unsafe { std::ptr::read_unaligned(buf.as_ptr() as *const T) };
356         if !msg.is_valid() {
357             return Err(Error::InvalidMessage);
358         }
359         Ok(msg)
360     }
361 
new_reply_header<T: Sized>( &self, req: &VhostUserMsgHeader<SlaveReq>, ) -> Result<VhostUserMsgHeader<SlaveReq>>362     fn new_reply_header<T: Sized>(
363         &self,
364         req: &VhostUserMsgHeader<SlaveReq>,
365     ) -> Result<VhostUserMsgHeader<SlaveReq>> {
366         if mem::size_of::<T>() > MAX_MSG_SIZE {
367             return Err(Error::InvalidParam);
368         }
369         self.check_state()?;
370         Ok(VhostUserMsgHeader::new(
371             req.get_code(),
372             VhostUserHeaderFlag::REPLY.bits(),
373             mem::size_of::<T>() as u32,
374         ))
375     }
376 
send_ack_message( &mut self, req: &VhostUserMsgHeader<SlaveReq>, res: &Result<u64>, ) -> Result<()>377     fn send_ack_message(
378         &mut self,
379         req: &VhostUserMsgHeader<SlaveReq>,
380         res: &Result<u64>,
381     ) -> Result<()> {
382         if self.reply_ack_negotiated && req.is_need_reply() {
383             let hdr = self.new_reply_header::<VhostUserU64>(req)?;
384             let def_err = libc::EINVAL;
385             let val = match res {
386                 Ok(n) => *n,
387                 Err(e) => match &*e {
388                     Error::ReqHandlerError(ioerr) => match ioerr.raw_os_error() {
389                         Some(rawerr) => -rawerr as u64,
390                         None => -def_err as u64,
391                     },
392                     _ => -def_err as u64,
393                 },
394             };
395             let msg = VhostUserU64::new(val);
396             self.sub_sock.send_message(&hdr, &msg, None)?;
397         }
398         Ok(())
399     }
400 }
401 
402 #[cfg(unix)]
403 impl<S: VhostUserMasterReqHandler> AsRawDescriptor for MasterReqHandler<S> {
as_raw_descriptor(&self) -> RawDescriptor404     fn as_raw_descriptor(&self) -> RawDescriptor {
405         // TODO(b/221882601): figure out whether this is used for polling. If so, we need theTube's
406         // read notifier here instead.
407         self.sub_sock.as_raw_descriptor()
408     }
409 }
410 
#[cfg(unix)]
#[cfg(test)]
mod tests {
    use super::*;
    use base::{AsRawDescriptor, INVALID_DESCRIPTOR};
    #[cfg(feature = "device")]
    use base::{Descriptor, FromRawDescriptor};

    #[cfg(feature = "device")]
    use crate::SlaveFsCacheReq;

    /// Backend whose map request succeeds with 0 and whose unmap request fails with ENOSYS.
    struct MockMasterReqHandler {}

    impl VhostUserMasterReqHandlerMut for MockMasterReqHandler {
        /// Handle virtio-fs map file requests from the slave.
        fn fs_slave_map(
            &mut self,
            _fs: &VhostUserFSSlaveMsg,
            _fd: &dyn AsRawDescriptor,
        ) -> HandlerResult<u64> {
            Ok(0)
        }

        /// Handle virtio-fs unmap file requests from the slave.
        fn fs_slave_unmap(&mut self, _fs: &VhostUserFSSlaveMsg) -> HandlerResult<u64> {
            Err(std::io::Error::from_raw_os_error(libc::ENOSYS))
        }
    }

    #[test]
    fn test_new_master_req_handler() {
        let backend = Arc::new(Mutex::new(MockMasterReqHandler {}));
        let mut handler = MasterReqHandler::new(backend).unwrap();

        assert!(handler.get_tx_raw_fd() >= 0);
        assert!(handler.as_raw_descriptor() != INVALID_DESCRIPTOR);
        handler.check_state().unwrap();

        assert_eq!(handler.error, None);
        handler.set_failed(libc::EAGAIN);
        assert_eq!(handler.error, Some(libc::EAGAIN));
        handler.check_state().unwrap_err();
    }

    #[cfg(feature = "device")]
    #[test]
    fn test_master_slave_req_handler() {
        let backend = Arc::new(Mutex::new(MockMasterReqHandler {}));
        let mut handler = MasterReqHandler::new(backend).unwrap();

        let fd = unsafe { libc::dup(handler.get_tx_raw_fd()) };
        if fd < 0 {
            panic!("failed to duplicated tx fd!");
        }
        let stream = unsafe { SystemStream::from_raw_descriptor(fd) };
        let fs_cache = SlaveFsCacheReq::from_stream(stream);

        let worker = std::thread::spawn(move || {
            let res = handler.handle_request().unwrap();
            assert_eq!(res, 0);
            handler.handle_request().unwrap_err();
        });

        fs_cache
            .fs_slave_map(&VhostUserFSSlaveMsg::default(), &Descriptor(fd))
            .unwrap();
        // When REPLY_ACK has not been negotiated, the master has no way to detect failure from
        // slave side.
        fs_cache
            .fs_slave_unmap(&VhostUserFSSlaveMsg::default())
            .unwrap();

        // Join the worker so that a panic (failed assertion) inside it fails this test instead
        // of being silently dropped along with a detached thread.
        worker
            .join()
            .expect("master request handler thread panicked");
    }

    #[cfg(feature = "device")]
    #[test]
    fn test_master_slave_req_handler_with_ack() {
        let backend = Arc::new(Mutex::new(MockMasterReqHandler {}));
        let mut handler = MasterReqHandler::new(backend).unwrap();
        handler.set_reply_ack_flag(true);

        let fd = unsafe { libc::dup(handler.get_tx_raw_fd()) };
        if fd < 0 {
            panic!("failed to duplicated tx fd!");
        }
        let stream = unsafe { SystemStream::from_raw_descriptor(fd) };
        let fs_cache = SlaveFsCacheReq::from_stream(stream);

        let worker = std::thread::spawn(move || {
            let res = handler.handle_request().unwrap();
            assert_eq!(res, 0);
            handler.handle_request().unwrap_err();
        });

        fs_cache.set_reply_ack_flag(true);
        fs_cache
            .fs_slave_map(&VhostUserFSSlaveMsg::default(), &Descriptor(fd))
            .unwrap();
        fs_cache
            .fs_slave_unmap(&VhostUserFSSlaveMsg::default())
            .unwrap_err();

        // Join the worker so that a panic (failed assertion) inside it fails this test instead
        // of being silently dropped along with a detached thread.
        worker
            .join()
            .expect("master request handler thread panicked");
    }
}
513