1 // Copyright (C) 2019-2021 Alibaba Cloud. All rights reserved. 2 // SPDX-License-Identifier: Apache-2.0 3 4 use std::mem; 5 use std::os::unix::io::{AsRawFd, RawFd}; 6 use std::os::unix::net::UnixStream; 7 use std::sync::{Arc, Mutex}; 8 9 use super::connection::Endpoint; 10 use super::message::*; 11 use super::{Error, HandlerResult, Result}; 12 13 /// Define services provided by masters for the slave communication channel. 14 /// 15 /// The vhost-user specification defines a slave communication channel, by which slaves could 16 /// request services from masters. The [VhostUserMasterReqHandler] trait defines services provided 17 /// by masters, and it's used both on the master side and slave side. 18 /// - on the slave side, a stub forwarder implementing [VhostUserMasterReqHandler] will proxy 19 /// service requests to masters. The [SlaveFsCacheReq] is an example stub forwarder. 20 /// - on the master side, the [MasterReqHandler] will forward service requests to a handler 21 /// implementing [VhostUserMasterReqHandler]. 22 /// 23 /// The [VhostUserMasterReqHandler] trait is design with interior mutability to improve performance 24 /// for multi-threading. 25 /// 26 /// [VhostUserMasterReqHandler]: trait.VhostUserMasterReqHandler.html 27 /// [MasterReqHandler]: struct.MasterReqHandler.html 28 /// [SlaveFsCacheReq]: struct.SlaveFsCacheReq.html 29 pub trait VhostUserMasterReqHandler { 30 /// Handle device configuration change notifications. handle_config_change(&self) -> HandlerResult<u64>31 fn handle_config_change(&self) -> HandlerResult<u64> { 32 Err(std::io::Error::from_raw_os_error(libc::ENOSYS)) 33 } 34 35 /// Handle virtio-fs map file requests. fs_slave_map(&self, _fs: &VhostUserFSSlaveMsg, fd: RawFd) -> HandlerResult<u64>36 fn fs_slave_map(&self, _fs: &VhostUserFSSlaveMsg, fd: RawFd) -> HandlerResult<u64> { 37 // Safe because we have just received the rawfd from kernel. 38 unsafe { libc::close(fd) }; 39 Err(std::io::Error::from_raw_os_error(libc::ENOSYS)) 40 } 41 42 /// Handle virtio-fs unmap file requests. fs_slave_unmap(&self, _fs: &VhostUserFSSlaveMsg) -> HandlerResult<u64>43 fn fs_slave_unmap(&self, _fs: &VhostUserFSSlaveMsg) -> HandlerResult<u64> { 44 Err(std::io::Error::from_raw_os_error(libc::ENOSYS)) 45 } 46 47 /// Handle virtio-fs sync file requests. fs_slave_sync(&self, _fs: &VhostUserFSSlaveMsg) -> HandlerResult<u64>48 fn fs_slave_sync(&self, _fs: &VhostUserFSSlaveMsg) -> HandlerResult<u64> { 49 Err(std::io::Error::from_raw_os_error(libc::ENOSYS)) 50 } 51 52 /// Handle virtio-fs file IO requests. fs_slave_io(&self, _fs: &VhostUserFSSlaveMsg, fd: RawFd) -> HandlerResult<u64>53 fn fs_slave_io(&self, _fs: &VhostUserFSSlaveMsg, fd: RawFd) -> HandlerResult<u64> { 54 // Safe because we have just received the rawfd from kernel. 55 unsafe { libc::close(fd) }; 56 Err(std::io::Error::from_raw_os_error(libc::ENOSYS)) 57 } 58 59 // fn handle_iotlb_msg(&mut self, iotlb: VhostUserIotlb); 60 // fn handle_vring_host_notifier(&mut self, area: VhostUserVringArea, fd: RawFd); 61 } 62 63 /// A helper trait mirroring [VhostUserMasterReqHandler] but without interior mutability. 64 /// 65 /// [VhostUserMasterReqHandler]: trait.VhostUserMasterReqHandler.html 66 pub trait VhostUserMasterReqHandlerMut { 67 /// Handle device configuration change notifications. handle_config_change(&mut self) -> HandlerResult<u64>68 fn handle_config_change(&mut self) -> HandlerResult<u64> { 69 Err(std::io::Error::from_raw_os_error(libc::ENOSYS)) 70 } 71 72 /// Handle virtio-fs map file requests. fs_slave_map(&mut self, _fs: &VhostUserFSSlaveMsg, fd: RawFd) -> HandlerResult<u64>73 fn fs_slave_map(&mut self, _fs: &VhostUserFSSlaveMsg, fd: RawFd) -> HandlerResult<u64> { 74 // Safe because we have just received the rawfd from kernel. 75 unsafe { libc::close(fd) }; 76 Err(std::io::Error::from_raw_os_error(libc::ENOSYS)) 77 } 78 79 /// Handle virtio-fs unmap file requests. fs_slave_unmap(&mut self, _fs: &VhostUserFSSlaveMsg) -> HandlerResult<u64>80 fn fs_slave_unmap(&mut self, _fs: &VhostUserFSSlaveMsg) -> HandlerResult<u64> { 81 Err(std::io::Error::from_raw_os_error(libc::ENOSYS)) 82 } 83 84 /// Handle virtio-fs sync file requests. fs_slave_sync(&mut self, _fs: &VhostUserFSSlaveMsg) -> HandlerResult<u64>85 fn fs_slave_sync(&mut self, _fs: &VhostUserFSSlaveMsg) -> HandlerResult<u64> { 86 Err(std::io::Error::from_raw_os_error(libc::ENOSYS)) 87 } 88 89 /// Handle virtio-fs file IO requests. fs_slave_io(&mut self, _fs: &VhostUserFSSlaveMsg, fd: RawFd) -> HandlerResult<u64>90 fn fs_slave_io(&mut self, _fs: &VhostUserFSSlaveMsg, fd: RawFd) -> HandlerResult<u64> { 91 // Safe because we have just received the rawfd from kernel. 92 unsafe { libc::close(fd) }; 93 Err(std::io::Error::from_raw_os_error(libc::ENOSYS)) 94 } 95 96 // fn handle_iotlb_msg(&mut self, iotlb: VhostUserIotlb); 97 // fn handle_vring_host_notifier(&mut self, area: VhostUserVringArea, fd: RawFd); 98 } 99 100 impl<S: VhostUserMasterReqHandlerMut> VhostUserMasterReqHandler for Mutex<S> { handle_config_change(&self) -> HandlerResult<u64>101 fn handle_config_change(&self) -> HandlerResult<u64> { 102 self.lock().unwrap().handle_config_change() 103 } 104 fs_slave_map(&self, fs: &VhostUserFSSlaveMsg, fd: RawFd) -> HandlerResult<u64>105 fn fs_slave_map(&self, fs: &VhostUserFSSlaveMsg, fd: RawFd) -> HandlerResult<u64> { 106 self.lock().unwrap().fs_slave_map(fs, fd) 107 } 108 fs_slave_unmap(&self, fs: &VhostUserFSSlaveMsg) -> HandlerResult<u64>109 fn fs_slave_unmap(&self, fs: &VhostUserFSSlaveMsg) -> HandlerResult<u64> { 110 self.lock().unwrap().fs_slave_unmap(fs) 111 } 112 fs_slave_sync(&self, fs: &VhostUserFSSlaveMsg) -> HandlerResult<u64>113 fn fs_slave_sync(&self, fs: &VhostUserFSSlaveMsg) -> HandlerResult<u64> { 114 self.lock().unwrap().fs_slave_sync(fs) 115 } 116 fs_slave_io(&self, fs: &VhostUserFSSlaveMsg, fd: RawFd) -> HandlerResult<u64>117 fn fs_slave_io(&self, fs: &VhostUserFSSlaveMsg, fd: RawFd) -> HandlerResult<u64> { 118 self.lock().unwrap().fs_slave_io(fs, fd) 119 } 120 } 121 122 /// Server to handle service requests from slaves from the slave communication channel. 123 /// 124 /// The [MasterReqHandler] acts as a server on the master side, to handle service requests from 125 /// slaves on the slave communication channel. It's actually a proxy invoking the registered 126 /// handler implementing [VhostUserMasterReqHandler] to do the real work. 127 /// 128 /// [MasterReqHandler]: struct.MasterReqHandler.html 129 /// [VhostUserMasterReqHandler]: trait.VhostUserMasterReqHandler.html 130 pub struct MasterReqHandler<S: VhostUserMasterReqHandler> { 131 // underlying Unix domain socket for communication 132 sub_sock: Endpoint<SlaveReq>, 133 tx_sock: UnixStream, 134 // Protocol feature VHOST_USER_PROTOCOL_F_REPLY_ACK has been negotiated. 135 reply_ack_negotiated: bool, 136 // the VirtIO backend device object 137 backend: Arc<S>, 138 // whether the endpoint has encountered any failure 139 error: Option<i32>, 140 } 141 142 impl<S: VhostUserMasterReqHandler> MasterReqHandler<S> { 143 /// Create a server to handle service requests from slaves on the slave communication channel. 144 /// 145 /// This opens a pair of connected anonymous sockets to form the slave communication channel. 146 /// The socket fd returned by [Self::get_tx_raw_fd()] should be sent to the slave by 147 /// [VhostUserMaster::set_slave_request_fd()]. 148 /// 149 /// [Self::get_tx_raw_fd()]: struct.MasterReqHandler.html#method.get_tx_raw_fd 150 /// [VhostUserMaster::set_slave_request_fd()]: trait.VhostUserMaster.html#tymethod.set_slave_request_fd new(backend: Arc<S>) -> Result<Self>151 pub fn new(backend: Arc<S>) -> Result<Self> { 152 let (tx, rx) = UnixStream::pair().map_err(Error::SocketError)?; 153 154 Ok(MasterReqHandler { 155 sub_sock: Endpoint::<SlaveReq>::from_stream(rx), 156 tx_sock: tx, 157 reply_ack_negotiated: false, 158 backend, 159 error: None, 160 }) 161 } 162 163 /// Get the socket fd for the slave to communication with the master. 164 /// 165 /// The returned fd should be sent to the slave by [VhostUserMaster::set_slave_request_fd()]. 166 /// 167 /// [VhostUserMaster::set_slave_request_fd()]: trait.VhostUserMaster.html#tymethod.set_slave_request_fd get_tx_raw_fd(&self) -> RawFd168 pub fn get_tx_raw_fd(&self) -> RawFd { 169 self.tx_sock.as_raw_fd() 170 } 171 172 /// Set the negotiation state of the `VHOST_USER_PROTOCOL_F_REPLY_ACK` protocol feature. 173 /// 174 /// When the `VHOST_USER_PROTOCOL_F_REPLY_ACK` protocol feature has been negotiated, 175 /// the "REPLY_ACK" flag will be set in the message header for every slave to master request 176 /// message. set_reply_ack_flag(&mut self, enable: bool)177 pub fn set_reply_ack_flag(&mut self, enable: bool) { 178 self.reply_ack_negotiated = enable; 179 } 180 181 /// Mark endpoint as failed or in normal state. set_failed(&mut self, error: i32)182 pub fn set_failed(&mut self, error: i32) { 183 if error == 0 { 184 self.error = None; 185 } else { 186 self.error = Some(error); 187 } 188 } 189 190 /// Main entrance to server slave request from the slave communication channel. 191 /// 192 /// The caller needs to: 193 /// - serialize calls to this function 194 /// - decide what to do when errer happens 195 /// - optional recover from failure handle_request(&mut self) -> Result<u64>196 pub fn handle_request(&mut self) -> Result<u64> { 197 // Return error if the endpoint is already in failed state. 198 self.check_state()?; 199 200 // The underlying communication channel is a Unix domain socket in 201 // stream mode, and recvmsg() is a little tricky here. To successfully 202 // receive attached file descriptors, we need to receive messages and 203 // corresponding attached file descriptors in this way: 204 // . recv messsage header and optional attached file 205 // . validate message header 206 // . recv optional message body and payload according size field in 207 // message header 208 // . validate message body and optional payload 209 let (hdr, rfds) = self.sub_sock.recv_header()?; 210 let rfds = self.check_attached_rfds(&hdr, rfds)?; 211 let (size, buf) = match hdr.get_size() { 212 0 => (0, vec![0u8; 0]), 213 len => { 214 if len as usize > MAX_MSG_SIZE { 215 return Err(Error::InvalidMessage); 216 } 217 let (size2, rbuf) = self.sub_sock.recv_data(len as usize)?; 218 if size2 != len as usize { 219 return Err(Error::InvalidMessage); 220 } 221 (size2, rbuf) 222 } 223 }; 224 225 let res = match hdr.get_code() { 226 SlaveReq::CONFIG_CHANGE_MSG => { 227 self.check_msg_size(&hdr, size, 0)?; 228 self.backend 229 .handle_config_change() 230 .map_err(Error::ReqHandlerError) 231 } 232 SlaveReq::FS_MAP => { 233 let msg = self.extract_msg_body::<VhostUserFSSlaveMsg>(&hdr, size, &buf)?; 234 // check_attached_rfds() has validated rfds 235 self.backend 236 .fs_slave_map(&msg, rfds.unwrap()[0]) 237 .map_err(Error::ReqHandlerError) 238 } 239 SlaveReq::FS_UNMAP => { 240 let msg = self.extract_msg_body::<VhostUserFSSlaveMsg>(&hdr, size, &buf)?; 241 self.backend 242 .fs_slave_unmap(&msg) 243 .map_err(Error::ReqHandlerError) 244 } 245 SlaveReq::FS_SYNC => { 246 let msg = self.extract_msg_body::<VhostUserFSSlaveMsg>(&hdr, size, &buf)?; 247 self.backend 248 .fs_slave_sync(&msg) 249 .map_err(Error::ReqHandlerError) 250 } 251 SlaveReq::FS_IO => { 252 let msg = self.extract_msg_body::<VhostUserFSSlaveMsg>(&hdr, size, &buf)?; 253 // check_attached_rfds() has validated rfds 254 self.backend 255 .fs_slave_io(&msg, rfds.unwrap()[0]) 256 .map_err(Error::ReqHandlerError) 257 } 258 _ => Err(Error::InvalidMessage), 259 }; 260 261 self.send_ack_message(&hdr, &res)?; 262 263 res 264 } 265 check_state(&self) -> Result<()>266 fn check_state(&self) -> Result<()> { 267 match self.error { 268 Some(e) => Err(Error::SocketBroken(std::io::Error::from_raw_os_error(e))), 269 None => Ok(()), 270 } 271 } 272 check_msg_size( &self, hdr: &VhostUserMsgHeader<SlaveReq>, size: usize, expected: usize, ) -> Result<()>273 fn check_msg_size( 274 &self, 275 hdr: &VhostUserMsgHeader<SlaveReq>, 276 size: usize, 277 expected: usize, 278 ) -> Result<()> { 279 if hdr.get_size() as usize != expected 280 || hdr.is_reply() 281 || hdr.get_version() != 0x1 282 || size != expected 283 { 284 return Err(Error::InvalidMessage); 285 } 286 Ok(()) 287 } 288 check_attached_rfds( &self, hdr: &VhostUserMsgHeader<SlaveReq>, rfds: Option<Vec<RawFd>>, ) -> Result<Option<Vec<RawFd>>>289 fn check_attached_rfds( 290 &self, 291 hdr: &VhostUserMsgHeader<SlaveReq>, 292 rfds: Option<Vec<RawFd>>, 293 ) -> Result<Option<Vec<RawFd>>> { 294 match hdr.get_code() { 295 SlaveReq::FS_MAP | SlaveReq::FS_IO => { 296 // Expect an fd set with a single fd. 297 match rfds { 298 None => Err(Error::InvalidMessage), 299 Some(fds) => { 300 if fds.len() != 1 { 301 Endpoint::<SlaveReq>::close_rfds(Some(fds)); 302 Err(Error::InvalidMessage) 303 } else { 304 Ok(Some(fds)) 305 } 306 } 307 } 308 } 309 _ => { 310 if rfds.is_some() { 311 Endpoint::<SlaveReq>::close_rfds(rfds); 312 Err(Error::InvalidMessage) 313 } else { 314 Ok(rfds) 315 } 316 } 317 } 318 } 319 extract_msg_body<T: Sized + VhostUserMsgValidator>( &self, hdr: &VhostUserMsgHeader<SlaveReq>, size: usize, buf: &[u8], ) -> Result<T>320 fn extract_msg_body<T: Sized + VhostUserMsgValidator>( 321 &self, 322 hdr: &VhostUserMsgHeader<SlaveReq>, 323 size: usize, 324 buf: &[u8], 325 ) -> Result<T> { 326 self.check_msg_size(hdr, size, mem::size_of::<T>())?; 327 let msg = unsafe { std::ptr::read_unaligned(buf.as_ptr() as *const T) }; 328 if !msg.is_valid() { 329 return Err(Error::InvalidMessage); 330 } 331 Ok(msg) 332 } 333 new_reply_header<T: Sized>( &self, req: &VhostUserMsgHeader<SlaveReq>, ) -> Result<VhostUserMsgHeader<SlaveReq>>334 fn new_reply_header<T: Sized>( 335 &self, 336 req: &VhostUserMsgHeader<SlaveReq>, 337 ) -> Result<VhostUserMsgHeader<SlaveReq>> { 338 if mem::size_of::<T>() > MAX_MSG_SIZE { 339 return Err(Error::InvalidParam); 340 } 341 self.check_state()?; 342 Ok(VhostUserMsgHeader::new( 343 req.get_code(), 344 VhostUserHeaderFlag::REPLY.bits(), 345 mem::size_of::<T>() as u32, 346 )) 347 } 348 send_ack_message( &mut self, req: &VhostUserMsgHeader<SlaveReq>, res: &Result<u64>, ) -> Result<()>349 fn send_ack_message( 350 &mut self, 351 req: &VhostUserMsgHeader<SlaveReq>, 352 res: &Result<u64>, 353 ) -> Result<()> { 354 if self.reply_ack_negotiated && req.is_need_reply() { 355 let hdr = self.new_reply_header::<VhostUserU64>(req)?; 356 let def_err = libc::EINVAL; 357 let val = match res { 358 Ok(n) => *n, 359 Err(e) => match &*e { 360 Error::ReqHandlerError(ioerr) => match ioerr.raw_os_error() { 361 Some(rawerr) => -rawerr as u64, 362 None => -def_err as u64, 363 }, 364 _ => -def_err as u64, 365 }, 366 }; 367 let msg = VhostUserU64::new(val); 368 self.sub_sock.send_message(&hdr, &msg, None)?; 369 } 370 Ok(()) 371 } 372 } 373 374 impl<S: VhostUserMasterReqHandler> AsRawFd for MasterReqHandler<S> { as_raw_fd(&self) -> RawFd375 fn as_raw_fd(&self) -> RawFd { 376 self.sub_sock.as_raw_fd() 377 } 378 } 379 380 #[cfg(test)] 381 mod tests { 382 use super::*; 383 384 #[cfg(feature = "vhost-user-slave")] 385 use crate::vhost_user::SlaveFsCacheReq; 386 #[cfg(feature = "vhost-user-slave")] 387 use std::os::unix::io::FromRawFd; 388 389 struct MockMasterReqHandler {} 390 391 impl VhostUserMasterReqHandlerMut for MockMasterReqHandler { 392 /// Handle virtio-fs map file requests from the slave. fs_slave_map(&mut self, _fs: &VhostUserFSSlaveMsg, fd: RawFd) -> HandlerResult<u64>393 fn fs_slave_map(&mut self, _fs: &VhostUserFSSlaveMsg, fd: RawFd) -> HandlerResult<u64> { 394 // Safe because we have just received the rawfd from kernel. 395 unsafe { libc::close(fd) }; 396 Ok(0) 397 } 398 399 /// Handle virtio-fs unmap file requests from the slave. fs_slave_unmap(&mut self, _fs: &VhostUserFSSlaveMsg) -> HandlerResult<u64>400 fn fs_slave_unmap(&mut self, _fs: &VhostUserFSSlaveMsg) -> HandlerResult<u64> { 401 Err(std::io::Error::from_raw_os_error(libc::ENOSYS)) 402 } 403 } 404 405 #[test] test_new_master_req_handler()406 fn test_new_master_req_handler() { 407 let backend = Arc::new(Mutex::new(MockMasterReqHandler {})); 408 let mut handler = MasterReqHandler::new(backend).unwrap(); 409 410 assert!(handler.get_tx_raw_fd() >= 0); 411 assert!(handler.as_raw_fd() >= 0); 412 handler.check_state().unwrap(); 413 414 assert_eq!(handler.error, None); 415 handler.set_failed(libc::EAGAIN); 416 assert_eq!(handler.error, Some(libc::EAGAIN)); 417 handler.check_state().unwrap_err(); 418 } 419 420 #[cfg(feature = "vhost-user-slave")] 421 #[test] test_master_slave_req_handler()422 fn test_master_slave_req_handler() { 423 let backend = Arc::new(Mutex::new(MockMasterReqHandler {})); 424 let mut handler = MasterReqHandler::new(backend).unwrap(); 425 426 let fd = unsafe { libc::dup(handler.get_tx_raw_fd()) }; 427 if fd < 0 { 428 panic!("failed to duplicated tx fd!"); 429 } 430 let stream = unsafe { UnixStream::from_raw_fd(fd) }; 431 let fs_cache = SlaveFsCacheReq::from_stream(stream); 432 433 std::thread::spawn(move || { 434 let res = handler.handle_request().unwrap(); 435 assert_eq!(res, 0); 436 handler.handle_request().unwrap_err(); 437 }); 438 439 fs_cache 440 .fs_slave_map(&VhostUserFSSlaveMsg::default(), fd) 441 .unwrap(); 442 // When REPLY_ACK has not been negotiated, the master has no way to detect failure from 443 // slave side. 444 fs_cache 445 .fs_slave_unmap(&VhostUserFSSlaveMsg::default()) 446 .unwrap(); 447 } 448 449 #[cfg(feature = "vhost-user-slave")] 450 #[test] test_master_slave_req_handler_with_ack()451 fn test_master_slave_req_handler_with_ack() { 452 let backend = Arc::new(Mutex::new(MockMasterReqHandler {})); 453 let mut handler = MasterReqHandler::new(backend).unwrap(); 454 handler.set_reply_ack_flag(true); 455 456 let fd = unsafe { libc::dup(handler.get_tx_raw_fd()) }; 457 if fd < 0 { 458 panic!("failed to duplicated tx fd!"); 459 } 460 let stream = unsafe { UnixStream::from_raw_fd(fd) }; 461 let fs_cache = SlaveFsCacheReq::from_stream(stream); 462 463 std::thread::spawn(move || { 464 let res = handler.handle_request().unwrap(); 465 assert_eq!(res, 0); 466 handler.handle_request().unwrap_err(); 467 }); 468 469 fs_cache.set_reply_ack_flag(true); 470 fs_cache 471 .fs_slave_map(&VhostUserFSSlaveMsg::default(), fd) 472 .unwrap(); 473 fs_cache 474 .fs_slave_unmap(&VhostUserFSSlaveMsg::default()) 475 .unwrap_err(); 476 } 477 } 478