1 // Copyright (C) 2019-2021 Alibaba Cloud. All rights reserved. 2 // SPDX-License-Identifier: Apache-2.0 3 4 #[cfg(unix)] 5 use std::fs::File; 6 #[cfg(unix)] 7 use std::mem; 8 #[cfg(unix)] 9 use std::os::unix::io::AsRawFd; 10 use std::sync::Mutex; 11 12 #[cfg(unix)] 13 use super::connection::{socket::Endpoint as SocketEndpoint, EndpointExt}; 14 use super::message::*; 15 use super::HandlerResult; 16 #[cfg(unix)] 17 use super::{Error, Result}; 18 #[cfg(unix)] 19 use crate::SystemStream; 20 #[cfg(unix)] 21 use std::sync::Arc; 22 23 use base::AsRawDescriptor; 24 #[cfg(unix)] 25 use base::RawDescriptor; 26 27 /// Define services provided by masters for the slave communication channel. 28 /// 29 /// The vhost-user specification defines a slave communication channel, by which slaves could 30 /// request services from masters. The [VhostUserMasterReqHandler] trait defines services provided 31 /// by masters, and it's used both on the master side and slave side. 32 /// - on the slave side, a stub forwarder implementing [VhostUserMasterReqHandler] will proxy 33 /// service requests to masters. The [SlaveFsCacheReq] is an example stub forwarder. 34 /// - on the master side, the [MasterReqHandler] will forward service requests to a handler 35 /// implementing [VhostUserMasterReqHandler]. 36 /// 37 /// The [VhostUserMasterReqHandler] trait is design with interior mutability to improve performance 38 /// for multi-threading. 39 /// 40 /// [VhostUserMasterReqHandler]: trait.VhostUserMasterReqHandler.html 41 /// [MasterReqHandler]: struct.MasterReqHandler.html 42 /// [SlaveFsCacheReq]: struct.SlaveFsCacheReq.html 43 pub trait VhostUserMasterReqHandler { 44 /// Handle device configuration change notifications. handle_config_change(&self) -> HandlerResult<u64>45 fn handle_config_change(&self) -> HandlerResult<u64> { 46 Err(std::io::Error::from_raw_os_error(libc::ENOSYS)) 47 } 48 49 /// Handle virtio-fs map file requests. fs_slave_map( &self, _fs: &VhostUserFSSlaveMsg, _fd: &dyn AsRawDescriptor, ) -> HandlerResult<u64>50 fn fs_slave_map( 51 &self, 52 _fs: &VhostUserFSSlaveMsg, 53 _fd: &dyn AsRawDescriptor, 54 ) -> HandlerResult<u64> { 55 Err(std::io::Error::from_raw_os_error(libc::ENOSYS)) 56 } 57 58 /// Handle virtio-fs unmap file requests. fs_slave_unmap(&self, _fs: &VhostUserFSSlaveMsg) -> HandlerResult<u64>59 fn fs_slave_unmap(&self, _fs: &VhostUserFSSlaveMsg) -> HandlerResult<u64> { 60 Err(std::io::Error::from_raw_os_error(libc::ENOSYS)) 61 } 62 63 /// Handle virtio-fs sync file requests. fs_slave_sync(&self, _fs: &VhostUserFSSlaveMsg) -> HandlerResult<u64>64 fn fs_slave_sync(&self, _fs: &VhostUserFSSlaveMsg) -> HandlerResult<u64> { 65 Err(std::io::Error::from_raw_os_error(libc::ENOSYS)) 66 } 67 68 /// Handle virtio-fs file IO requests. fs_slave_io( &self, _fs: &VhostUserFSSlaveMsg, _fd: &dyn AsRawDescriptor, ) -> HandlerResult<u64>69 fn fs_slave_io( 70 &self, 71 _fs: &VhostUserFSSlaveMsg, 72 _fd: &dyn AsRawDescriptor, 73 ) -> HandlerResult<u64> { 74 Err(std::io::Error::from_raw_os_error(libc::ENOSYS)) 75 } 76 77 // fn handle_iotlb_msg(&mut self, iotlb: VhostUserIotlb); 78 // fn handle_vring_host_notifier(&mut self, area: VhostUserVringArea, fd: &dyn AsRawDescriptor); 79 } 80 81 /// A helper trait mirroring [VhostUserMasterReqHandler] but without interior mutability. 82 /// 83 /// [VhostUserMasterReqHandler]: trait.VhostUserMasterReqHandler.html 84 pub trait VhostUserMasterReqHandlerMut { 85 /// Handle device configuration change notifications. handle_config_change(&mut self) -> HandlerResult<u64>86 fn handle_config_change(&mut self) -> HandlerResult<u64> { 87 Err(std::io::Error::from_raw_os_error(libc::ENOSYS)) 88 } 89 90 /// Handle virtio-fs map file requests. fs_slave_map( &mut self, _fs: &VhostUserFSSlaveMsg, _fd: &dyn AsRawDescriptor, ) -> HandlerResult<u64>91 fn fs_slave_map( 92 &mut self, 93 _fs: &VhostUserFSSlaveMsg, 94 _fd: &dyn AsRawDescriptor, 95 ) -> HandlerResult<u64> { 96 Err(std::io::Error::from_raw_os_error(libc::ENOSYS)) 97 } 98 99 /// Handle virtio-fs unmap file requests. fs_slave_unmap(&mut self, _fs: &VhostUserFSSlaveMsg) -> HandlerResult<u64>100 fn fs_slave_unmap(&mut self, _fs: &VhostUserFSSlaveMsg) -> HandlerResult<u64> { 101 Err(std::io::Error::from_raw_os_error(libc::ENOSYS)) 102 } 103 104 /// Handle virtio-fs sync file requests. fs_slave_sync(&mut self, _fs: &VhostUserFSSlaveMsg) -> HandlerResult<u64>105 fn fs_slave_sync(&mut self, _fs: &VhostUserFSSlaveMsg) -> HandlerResult<u64> { 106 Err(std::io::Error::from_raw_os_error(libc::ENOSYS)) 107 } 108 109 /// Handle virtio-fs file IO requests. fs_slave_io( &mut self, _fs: &VhostUserFSSlaveMsg, _fd: &dyn AsRawDescriptor, ) -> HandlerResult<u64>110 fn fs_slave_io( 111 &mut self, 112 _fs: &VhostUserFSSlaveMsg, 113 _fd: &dyn AsRawDescriptor, 114 ) -> HandlerResult<u64> { 115 Err(std::io::Error::from_raw_os_error(libc::ENOSYS)) 116 } 117 118 // fn handle_iotlb_msg(&mut self, iotlb: VhostUserIotlb); 119 // fn handle_vring_host_notifier(&mut self, area: VhostUserVringArea, fd: RawDescriptor); 120 } 121 122 impl<S: VhostUserMasterReqHandlerMut> VhostUserMasterReqHandler for Mutex<S> { handle_config_change(&self) -> HandlerResult<u64>123 fn handle_config_change(&self) -> HandlerResult<u64> { 124 self.lock().unwrap().handle_config_change() 125 } 126 fs_slave_map( &self, fs: &VhostUserFSSlaveMsg, fd: &dyn AsRawDescriptor, ) -> HandlerResult<u64>127 fn fs_slave_map( 128 &self, 129 fs: &VhostUserFSSlaveMsg, 130 fd: &dyn AsRawDescriptor, 131 ) -> HandlerResult<u64> { 132 self.lock().unwrap().fs_slave_map(fs, fd) 133 } 134 fs_slave_unmap(&self, fs: &VhostUserFSSlaveMsg) -> HandlerResult<u64>135 fn fs_slave_unmap(&self, fs: &VhostUserFSSlaveMsg) -> HandlerResult<u64> { 136 self.lock().unwrap().fs_slave_unmap(fs) 137 } 138 fs_slave_sync(&self, fs: &VhostUserFSSlaveMsg) -> HandlerResult<u64>139 fn fs_slave_sync(&self, fs: &VhostUserFSSlaveMsg) -> HandlerResult<u64> { 140 self.lock().unwrap().fs_slave_sync(fs) 141 } 142 fs_slave_io( &self, fs: &VhostUserFSSlaveMsg, fd: &dyn AsRawDescriptor, ) -> HandlerResult<u64>143 fn fs_slave_io( 144 &self, 145 fs: &VhostUserFSSlaveMsg, 146 fd: &dyn AsRawDescriptor, 147 ) -> HandlerResult<u64> { 148 self.lock().unwrap().fs_slave_io(fs, fd) 149 } 150 } 151 152 /// The [MasterReqHandler] acts as a server on the master side, to handle service requests from 153 /// slaves on the slave communication channel. It's actually a proxy invoking the registered 154 /// handler implementing [VhostUserMasterReqHandler] to do the real work. 155 /// 156 /// [MasterReqHandler]: struct.MasterReqHandler.html 157 /// [VhostUserMasterReqHandler]: trait.VhostUserMasterReqHandler.html 158 /// 159 /// TODO(b/221882601): we can write a version of this for Windows by switching the socket for a Tube. 160 /// The interfaces would need to change so that we fetch a full Tube (which is 2 rds on Windows) 161 /// and send it to the device backend (slave) as a message on the master -> slave channel. 162 /// (Currently the interface only supports sending a single rd.) 163 /// 164 /// Note that handling requests from slaves is not needed for the initial devices we plan to 165 /// support. 166 /// 167 /// Server to handle service requests from slaves from the slave communication channel. 168 #[cfg(unix)] 169 pub struct MasterReqHandler<S: VhostUserMasterReqHandler> { 170 // underlying Unix domain socket for communication 171 sub_sock: SocketEndpoint<SlaveReq>, 172 tx_sock: SystemStream, 173 // Protocol feature VHOST_USER_PROTOCOL_F_REPLY_ACK has been negotiated. 174 reply_ack_negotiated: bool, 175 // the VirtIO backend device object 176 backend: Arc<S>, 177 // whether the endpoint has encountered any failure 178 error: Option<i32>, 179 } 180 181 #[cfg(unix)] 182 impl<S: VhostUserMasterReqHandler> MasterReqHandler<S> { 183 /// Create a server to handle service requests from slaves on the slave communication channel. 184 /// 185 /// This opens a pair of connected anonymous sockets to form the slave communication channel. 186 /// The socket fd returned by [Self::get_tx_raw_fd()] should be sent to the slave by 187 /// [VhostUserMaster::set_slave_request_fd()]. 188 /// 189 /// [Self::get_tx_raw_fd()]: struct.MasterReqHandler.html#method.get_tx_raw_fd 190 /// [VhostUserMaster::set_slave_request_fd()]: trait.VhostUserMaster.html#tymethod.set_slave_request_fd new(backend: Arc<S>) -> Result<Self>191 pub fn new(backend: Arc<S>) -> Result<Self> { 192 let (tx, rx) = SystemStream::pair().map_err(Error::SocketError)?; 193 194 Ok(MasterReqHandler { 195 sub_sock: SocketEndpoint::<SlaveReq>::from(rx), 196 tx_sock: tx, 197 reply_ack_negotiated: false, 198 backend, 199 error: None, 200 }) 201 } 202 203 /// Get the socket fd for the slave to communication with the master. 204 /// 205 /// The returned fd should be sent to the slave by [VhostUserMaster::set_slave_request_fd()]. 206 /// 207 /// [VhostUserMaster::set_slave_request_fd()]: trait.VhostUserMaster.html#tymethod.set_slave_request_fd get_tx_raw_fd(&self) -> RawDescriptor208 pub fn get_tx_raw_fd(&self) -> RawDescriptor { 209 self.tx_sock.as_raw_fd() 210 } 211 212 /// Set the negotiation state of the `VHOST_USER_PROTOCOL_F_REPLY_ACK` protocol feature. 213 /// 214 /// When the `VHOST_USER_PROTOCOL_F_REPLY_ACK` protocol feature has been negotiated, 215 /// the "REPLY_ACK" flag will be set in the message header for every slave to master request 216 /// message. set_reply_ack_flag(&mut self, enable: bool)217 pub fn set_reply_ack_flag(&mut self, enable: bool) { 218 self.reply_ack_negotiated = enable; 219 } 220 221 /// Mark endpoint as failed or in normal state. set_failed(&mut self, error: i32)222 pub fn set_failed(&mut self, error: i32) { 223 if error == 0 { 224 self.error = None; 225 } else { 226 self.error = Some(error); 227 } 228 } 229 230 /// Main entrance to server slave request from the slave communication channel. 231 /// 232 /// The caller needs to: 233 /// - serialize calls to this function 234 /// - decide what to do when errer happens 235 /// - optional recover from failure handle_request(&mut self) -> Result<u64>236 pub fn handle_request(&mut self) -> Result<u64> { 237 // Return error if the endpoint is already in failed state. 238 self.check_state()?; 239 240 // The underlying communication channel is a Unix domain socket in 241 // stream mode, and recvmsg() is a little tricky here. To successfully 242 // receive attached file descriptors, we need to receive messages and 243 // corresponding attached file descriptors in this way: 244 // . recv messsage header and optional attached file 245 // . validate message header 246 // . recv optional message body and payload according size field in 247 // message header 248 // . validate message body and optional payload 249 let (hdr, files) = self.sub_sock.recv_header()?; 250 self.check_attached_files(&hdr, &files)?; 251 let buf = match hdr.get_size() { 252 0 => vec![0u8; 0], 253 len => { 254 if len as usize > MAX_MSG_SIZE { 255 return Err(Error::InvalidMessage); 256 } 257 let rbuf = self.sub_sock.recv_data(len as usize)?; 258 if rbuf.len() != len as usize { 259 return Err(Error::InvalidMessage); 260 } 261 rbuf 262 } 263 }; 264 let size = buf.len(); 265 266 let res = match hdr.get_code() { 267 SlaveReq::CONFIG_CHANGE_MSG => { 268 self.check_msg_size(&hdr, size, 0)?; 269 self.backend 270 .handle_config_change() 271 .map_err(Error::ReqHandlerError) 272 } 273 SlaveReq::FS_MAP => { 274 let msg = self.extract_msg_body::<VhostUserFSSlaveMsg>(&hdr, size, &buf)?; 275 // check_attached_files() has validated files 276 self.backend 277 .fs_slave_map(&msg, &files.unwrap()[0]) 278 .map_err(Error::ReqHandlerError) 279 } 280 SlaveReq::FS_UNMAP => { 281 let msg = self.extract_msg_body::<VhostUserFSSlaveMsg>(&hdr, size, &buf)?; 282 self.backend 283 .fs_slave_unmap(&msg) 284 .map_err(Error::ReqHandlerError) 285 } 286 SlaveReq::FS_SYNC => { 287 let msg = self.extract_msg_body::<VhostUserFSSlaveMsg>(&hdr, size, &buf)?; 288 self.backend 289 .fs_slave_sync(&msg) 290 .map_err(Error::ReqHandlerError) 291 } 292 SlaveReq::FS_IO => { 293 let msg = self.extract_msg_body::<VhostUserFSSlaveMsg>(&hdr, size, &buf)?; 294 // check_attached_files() has validated files 295 self.backend 296 .fs_slave_io(&msg, &files.unwrap()[0]) 297 .map_err(Error::ReqHandlerError) 298 } 299 _ => Err(Error::InvalidMessage), 300 }; 301 302 self.send_ack_message(&hdr, &res)?; 303 304 res 305 } 306 check_state(&self) -> Result<()>307 fn check_state(&self) -> Result<()> { 308 match self.error { 309 Some(e) => Err(Error::SocketBroken(std::io::Error::from_raw_os_error(e))), 310 None => Ok(()), 311 } 312 } 313 check_msg_size( &self, hdr: &VhostUserMsgHeader<SlaveReq>, size: usize, expected: usize, ) -> Result<()>314 fn check_msg_size( 315 &self, 316 hdr: &VhostUserMsgHeader<SlaveReq>, 317 size: usize, 318 expected: usize, 319 ) -> Result<()> { 320 if hdr.get_size() as usize != expected 321 || hdr.is_reply() 322 || hdr.get_version() != 0x1 323 || size != expected 324 { 325 return Err(Error::InvalidMessage); 326 } 327 Ok(()) 328 } 329 check_attached_files( &self, hdr: &VhostUserMsgHeader<SlaveReq>, files: &Option<Vec<File>>, ) -> Result<()>330 fn check_attached_files( 331 &self, 332 hdr: &VhostUserMsgHeader<SlaveReq>, 333 files: &Option<Vec<File>>, 334 ) -> Result<()> { 335 match hdr.get_code() { 336 SlaveReq::FS_MAP | SlaveReq::FS_IO => { 337 // Expect a single file is passed. 338 match files { 339 Some(files) if files.len() == 1 => Ok(()), 340 _ => Err(Error::InvalidMessage), 341 } 342 } 343 _ if files.is_some() => Err(Error::InvalidMessage), 344 _ => Ok(()), 345 } 346 } 347 extract_msg_body<T: Sized + VhostUserMsgValidator>( &self, hdr: &VhostUserMsgHeader<SlaveReq>, size: usize, buf: &[u8], ) -> Result<T>348 fn extract_msg_body<T: Sized + VhostUserMsgValidator>( 349 &self, 350 hdr: &VhostUserMsgHeader<SlaveReq>, 351 size: usize, 352 buf: &[u8], 353 ) -> Result<T> { 354 self.check_msg_size(hdr, size, mem::size_of::<T>())?; 355 let msg = unsafe { std::ptr::read_unaligned(buf.as_ptr() as *const T) }; 356 if !msg.is_valid() { 357 return Err(Error::InvalidMessage); 358 } 359 Ok(msg) 360 } 361 new_reply_header<T: Sized>( &self, req: &VhostUserMsgHeader<SlaveReq>, ) -> Result<VhostUserMsgHeader<SlaveReq>>362 fn new_reply_header<T: Sized>( 363 &self, 364 req: &VhostUserMsgHeader<SlaveReq>, 365 ) -> Result<VhostUserMsgHeader<SlaveReq>> { 366 if mem::size_of::<T>() > MAX_MSG_SIZE { 367 return Err(Error::InvalidParam); 368 } 369 self.check_state()?; 370 Ok(VhostUserMsgHeader::new( 371 req.get_code(), 372 VhostUserHeaderFlag::REPLY.bits(), 373 mem::size_of::<T>() as u32, 374 )) 375 } 376 send_ack_message( &mut self, req: &VhostUserMsgHeader<SlaveReq>, res: &Result<u64>, ) -> Result<()>377 fn send_ack_message( 378 &mut self, 379 req: &VhostUserMsgHeader<SlaveReq>, 380 res: &Result<u64>, 381 ) -> Result<()> { 382 if self.reply_ack_negotiated && req.is_need_reply() { 383 let hdr = self.new_reply_header::<VhostUserU64>(req)?; 384 let def_err = libc::EINVAL; 385 let val = match res { 386 Ok(n) => *n, 387 Err(e) => match &*e { 388 Error::ReqHandlerError(ioerr) => match ioerr.raw_os_error() { 389 Some(rawerr) => -rawerr as u64, 390 None => -def_err as u64, 391 }, 392 _ => -def_err as u64, 393 }, 394 }; 395 let msg = VhostUserU64::new(val); 396 self.sub_sock.send_message(&hdr, &msg, None)?; 397 } 398 Ok(()) 399 } 400 } 401 402 #[cfg(unix)] 403 impl<S: VhostUserMasterReqHandler> AsRawDescriptor for MasterReqHandler<S> { as_raw_descriptor(&self) -> RawDescriptor404 fn as_raw_descriptor(&self) -> RawDescriptor { 405 // TODO(b/221882601): figure out whether this is used for polling. If so, we need theTube's 406 // read notifier here instead. 407 self.sub_sock.as_raw_descriptor() 408 } 409 } 410 411 #[cfg(unix)] 412 #[cfg(test)] 413 mod tests { 414 use super::*; 415 use base::{AsRawDescriptor, INVALID_DESCRIPTOR}; 416 #[cfg(feature = "device")] 417 use base::{Descriptor, FromRawDescriptor}; 418 419 #[cfg(feature = "device")] 420 use crate::SlaveFsCacheReq; 421 422 struct MockMasterReqHandler {} 423 424 impl VhostUserMasterReqHandlerMut for MockMasterReqHandler { 425 /// Handle virtio-fs map file requests from the slave. fs_slave_map( &mut self, _fs: &VhostUserFSSlaveMsg, _fd: &dyn AsRawDescriptor, ) -> HandlerResult<u64>426 fn fs_slave_map( 427 &mut self, 428 _fs: &VhostUserFSSlaveMsg, 429 _fd: &dyn AsRawDescriptor, 430 ) -> HandlerResult<u64> { 431 Ok(0) 432 } 433 434 /// Handle virtio-fs unmap file requests from the slave. fs_slave_unmap(&mut self, _fs: &VhostUserFSSlaveMsg) -> HandlerResult<u64>435 fn fs_slave_unmap(&mut self, _fs: &VhostUserFSSlaveMsg) -> HandlerResult<u64> { 436 Err(std::io::Error::from_raw_os_error(libc::ENOSYS)) 437 } 438 } 439 440 #[test] test_new_master_req_handler()441 fn test_new_master_req_handler() { 442 let backend = Arc::new(Mutex::new(MockMasterReqHandler {})); 443 let mut handler = MasterReqHandler::new(backend).unwrap(); 444 445 assert!(handler.get_tx_raw_fd() >= 0); 446 assert!(handler.as_raw_descriptor() != INVALID_DESCRIPTOR); 447 handler.check_state().unwrap(); 448 449 assert_eq!(handler.error, None); 450 handler.set_failed(libc::EAGAIN); 451 assert_eq!(handler.error, Some(libc::EAGAIN)); 452 handler.check_state().unwrap_err(); 453 } 454 455 #[cfg(feature = "device")] 456 #[test] test_master_slave_req_handler()457 fn test_master_slave_req_handler() { 458 let backend = Arc::new(Mutex::new(MockMasterReqHandler {})); 459 let mut handler = MasterReqHandler::new(backend).unwrap(); 460 461 let fd = unsafe { libc::dup(handler.get_tx_raw_fd()) }; 462 if fd < 0 { 463 panic!("failed to duplicated tx fd!"); 464 } 465 let stream = unsafe { SystemStream::from_raw_descriptor(fd) }; 466 let fs_cache = SlaveFsCacheReq::from_stream(stream); 467 468 std::thread::spawn(move || { 469 let res = handler.handle_request().unwrap(); 470 assert_eq!(res, 0); 471 handler.handle_request().unwrap_err(); 472 }); 473 474 fs_cache 475 .fs_slave_map(&VhostUserFSSlaveMsg::default(), &Descriptor(fd)) 476 .unwrap(); 477 // When REPLY_ACK has not been negotiated, the master has no way to detect failure from 478 // slave side. 479 fs_cache 480 .fs_slave_unmap(&VhostUserFSSlaveMsg::default()) 481 .unwrap(); 482 } 483 484 #[cfg(feature = "device")] 485 #[test] test_master_slave_req_handler_with_ack()486 fn test_master_slave_req_handler_with_ack() { 487 let backend = Arc::new(Mutex::new(MockMasterReqHandler {})); 488 let mut handler = MasterReqHandler::new(backend).unwrap(); 489 handler.set_reply_ack_flag(true); 490 491 let fd = unsafe { libc::dup(handler.get_tx_raw_fd()) }; 492 if fd < 0 { 493 panic!("failed to duplicated tx fd!"); 494 } 495 let stream = unsafe { SystemStream::from_raw_descriptor(fd) }; 496 let fs_cache = SlaveFsCacheReq::from_stream(stream); 497 498 std::thread::spawn(move || { 499 let res = handler.handle_request().unwrap(); 500 assert_eq!(res, 0); 501 handler.handle_request().unwrap_err(); 502 }); 503 504 fs_cache.set_reply_ack_flag(true); 505 fs_cache 506 .fs_slave_map(&VhostUserFSSlaveMsg::default(), &Descriptor(fd)) 507 .unwrap(); 508 fs_cache 509 .fs_slave_unmap(&VhostUserFSSlaveMsg::default()) 510 .unwrap_err(); 511 } 512 } 513