1 // Copyright (C) 2019 Alibaba Cloud Computing. All rights reserved. 2 // SPDX-License-Identifier: Apache-2.0 3 4 use std::fs::File; 5 use std::mem; 6 7 use base::error; 8 use base::AsRawDescriptor; 9 use base::RawDescriptor; 10 use zerocopy::AsBytes; 11 use zerocopy::FromBytes; 12 use zerocopy::Ref; 13 14 use crate::into_single_file; 15 use crate::message::*; 16 use crate::to_system_stream; 17 use crate::BackendReq; 18 use crate::Connection; 19 use crate::Error; 20 use crate::FrontendReq; 21 use crate::Result; 22 use crate::SystemStream; 23 24 /// Trait for vhost-user backends. 25 /// 26 /// Each method corresponds to a vhost-user protocol method. See the specification for details. 27 #[allow(missing_docs)] 28 pub trait Backend { set_owner(&mut self) -> Result<()>29 fn set_owner(&mut self) -> Result<()>; reset_owner(&mut self) -> Result<()>30 fn reset_owner(&mut self) -> Result<()>; get_features(&mut self) -> Result<u64>31 fn get_features(&mut self) -> Result<u64>; set_features(&mut self, features: u64) -> Result<()>32 fn set_features(&mut self, features: u64) -> Result<()>; set_mem_table(&mut self, ctx: &[VhostUserMemoryRegion], files: Vec<File>) -> Result<()>33 fn set_mem_table(&mut self, ctx: &[VhostUserMemoryRegion], files: Vec<File>) -> Result<()>; set_vring_num(&mut self, index: u32, num: u32) -> Result<()>34 fn set_vring_num(&mut self, index: u32, num: u32) -> Result<()>; set_vring_addr( &mut self, index: u32, flags: VhostUserVringAddrFlags, descriptor: u64, used: u64, available: u64, log: u64, ) -> Result<()>35 fn set_vring_addr( 36 &mut self, 37 index: u32, 38 flags: VhostUserVringAddrFlags, 39 descriptor: u64, 40 used: u64, 41 available: u64, 42 log: u64, 43 ) -> Result<()>; 44 // TODO: b/331466964 - Argument type is wrong for packed queues. set_vring_base(&mut self, index: u32, base: u32) -> Result<()>45 fn set_vring_base(&mut self, index: u32, base: u32) -> Result<()>; 46 // TODO: b/331466964 - Return type is wrong for packed queues. 
get_vring_base(&mut self, index: u32) -> Result<VhostUserVringState>47 fn get_vring_base(&mut self, index: u32) -> Result<VhostUserVringState>; set_vring_kick(&mut self, index: u8, fd: Option<File>) -> Result<()>48 fn set_vring_kick(&mut self, index: u8, fd: Option<File>) -> Result<()>; set_vring_call(&mut self, index: u8, fd: Option<File>) -> Result<()>49 fn set_vring_call(&mut self, index: u8, fd: Option<File>) -> Result<()>; set_vring_err(&mut self, index: u8, fd: Option<File>) -> Result<()>50 fn set_vring_err(&mut self, index: u8, fd: Option<File>) -> Result<()>; 51 get_protocol_features(&mut self) -> Result<VhostUserProtocolFeatures>52 fn get_protocol_features(&mut self) -> Result<VhostUserProtocolFeatures>; set_protocol_features(&mut self, features: u64) -> Result<()>53 fn set_protocol_features(&mut self, features: u64) -> Result<()>; get_queue_num(&mut self) -> Result<u64>54 fn get_queue_num(&mut self) -> Result<u64>; set_vring_enable(&mut self, index: u32, enable: bool) -> Result<()>55 fn set_vring_enable(&mut self, index: u32, enable: bool) -> Result<()>; get_config( &mut self, offset: u32, size: u32, flags: VhostUserConfigFlags, ) -> Result<Vec<u8>>56 fn get_config( 57 &mut self, 58 offset: u32, 59 size: u32, 60 flags: VhostUserConfigFlags, 61 ) -> Result<Vec<u8>>; set_config(&mut self, offset: u32, buf: &[u8], flags: VhostUserConfigFlags) -> Result<()>62 fn set_config(&mut self, offset: u32, buf: &[u8], flags: VhostUserConfigFlags) -> Result<()>; set_backend_req_fd(&mut self, _vu_req: Connection<BackendReq>)63 fn set_backend_req_fd(&mut self, _vu_req: Connection<BackendReq>) {} get_inflight_fd( &mut self, inflight: &VhostUserInflight, ) -> Result<(VhostUserInflight, File)>64 fn get_inflight_fd( 65 &mut self, 66 inflight: &VhostUserInflight, 67 ) -> Result<(VhostUserInflight, File)>; set_inflight_fd(&mut self, inflight: &VhostUserInflight, file: File) -> Result<()>68 fn set_inflight_fd(&mut self, inflight: &VhostUserInflight, file: File) -> Result<()>; 
get_max_mem_slots(&mut self) -> Result<u64>69 fn get_max_mem_slots(&mut self) -> Result<u64>; add_mem_region(&mut self, region: &VhostUserSingleMemoryRegion, fd: File) -> Result<()>70 fn add_mem_region(&mut self, region: &VhostUserSingleMemoryRegion, fd: File) -> Result<()>; remove_mem_region(&mut self, region: &VhostUserSingleMemoryRegion) -> Result<()>71 fn remove_mem_region(&mut self, region: &VhostUserSingleMemoryRegion) -> Result<()>; get_shared_memory_regions(&mut self) -> Result<Vec<VhostSharedMemoryRegion>>72 fn get_shared_memory_regions(&mut self) -> Result<Vec<VhostSharedMemoryRegion>>; 73 /// Request the device to sleep by stopping their workers. This should NOT be called if the 74 /// device is already asleep. sleep(&mut self) -> Result<()>75 fn sleep(&mut self) -> Result<()>; 76 /// Request the device to wake up by starting up their workers. This should NOT be called if the 77 /// device is already awake. wake(&mut self) -> Result<()>78 fn wake(&mut self) -> Result<()>; snapshot(&mut self) -> Result<Vec<u8>>79 fn snapshot(&mut self) -> Result<Vec<u8>>; restore(&mut self, data_bytes: &[u8], queue_evts: Vec<File>) -> Result<()>80 fn restore(&mut self, data_bytes: &[u8], queue_evts: Vec<File>) -> Result<()>; 81 } 82 83 impl<T> Backend for T 84 where 85 T: AsMut<dyn Backend>, 86 { set_owner(&mut self) -> Result<()>87 fn set_owner(&mut self) -> Result<()> { 88 self.as_mut().set_owner() 89 } 90 reset_owner(&mut self) -> Result<()>91 fn reset_owner(&mut self) -> Result<()> { 92 self.as_mut().reset_owner() 93 } 94 get_features(&mut self) -> Result<u64>95 fn get_features(&mut self) -> Result<u64> { 96 self.as_mut().get_features() 97 } 98 set_features(&mut self, features: u64) -> Result<()>99 fn set_features(&mut self, features: u64) -> Result<()> { 100 self.as_mut().set_features(features) 101 } 102 set_mem_table(&mut self, ctx: &[VhostUserMemoryRegion], files: Vec<File>) -> Result<()>103 fn set_mem_table(&mut self, ctx: &[VhostUserMemoryRegion], files: 
Vec<File>) -> Result<()> { 104 self.as_mut().set_mem_table(ctx, files) 105 } 106 set_vring_num(&mut self, index: u32, num: u32) -> Result<()>107 fn set_vring_num(&mut self, index: u32, num: u32) -> Result<()> { 108 self.as_mut().set_vring_num(index, num) 109 } 110 set_vring_addr( &mut self, index: u32, flags: VhostUserVringAddrFlags, descriptor: u64, used: u64, available: u64, log: u64, ) -> Result<()>111 fn set_vring_addr( 112 &mut self, 113 index: u32, 114 flags: VhostUserVringAddrFlags, 115 descriptor: u64, 116 used: u64, 117 available: u64, 118 log: u64, 119 ) -> Result<()> { 120 self.as_mut() 121 .set_vring_addr(index, flags, descriptor, used, available, log) 122 } 123 set_vring_base(&mut self, index: u32, base: u32) -> Result<()>124 fn set_vring_base(&mut self, index: u32, base: u32) -> Result<()> { 125 self.as_mut().set_vring_base(index, base) 126 } 127 get_vring_base(&mut self, index: u32) -> Result<VhostUserVringState>128 fn get_vring_base(&mut self, index: u32) -> Result<VhostUserVringState> { 129 self.as_mut().get_vring_base(index) 130 } 131 set_vring_kick(&mut self, index: u8, fd: Option<File>) -> Result<()>132 fn set_vring_kick(&mut self, index: u8, fd: Option<File>) -> Result<()> { 133 self.as_mut().set_vring_kick(index, fd) 134 } 135 set_vring_call(&mut self, index: u8, fd: Option<File>) -> Result<()>136 fn set_vring_call(&mut self, index: u8, fd: Option<File>) -> Result<()> { 137 self.as_mut().set_vring_call(index, fd) 138 } 139 set_vring_err(&mut self, index: u8, fd: Option<File>) -> Result<()>140 fn set_vring_err(&mut self, index: u8, fd: Option<File>) -> Result<()> { 141 self.as_mut().set_vring_err(index, fd) 142 } 143 get_protocol_features(&mut self) -> Result<VhostUserProtocolFeatures>144 fn get_protocol_features(&mut self) -> Result<VhostUserProtocolFeatures> { 145 self.as_mut().get_protocol_features() 146 } 147 set_protocol_features(&mut self, features: u64) -> Result<()>148 fn set_protocol_features(&mut self, features: u64) -> Result<()> { 
149 self.as_mut().set_protocol_features(features) 150 } 151 get_queue_num(&mut self) -> Result<u64>152 fn get_queue_num(&mut self) -> Result<u64> { 153 self.as_mut().get_queue_num() 154 } 155 set_vring_enable(&mut self, index: u32, enable: bool) -> Result<()>156 fn set_vring_enable(&mut self, index: u32, enable: bool) -> Result<()> { 157 self.as_mut().set_vring_enable(index, enable) 158 } 159 get_config( &mut self, offset: u32, size: u32, flags: VhostUserConfigFlags, ) -> Result<Vec<u8>>160 fn get_config( 161 &mut self, 162 offset: u32, 163 size: u32, 164 flags: VhostUserConfigFlags, 165 ) -> Result<Vec<u8>> { 166 self.as_mut().get_config(offset, size, flags) 167 } 168 set_config(&mut self, offset: u32, buf: &[u8], flags: VhostUserConfigFlags) -> Result<()>169 fn set_config(&mut self, offset: u32, buf: &[u8], flags: VhostUserConfigFlags) -> Result<()> { 170 self.as_mut().set_config(offset, buf, flags) 171 } 172 set_backend_req_fd(&mut self, vu_req: Connection<BackendReq>)173 fn set_backend_req_fd(&mut self, vu_req: Connection<BackendReq>) { 174 self.as_mut().set_backend_req_fd(vu_req) 175 } 176 get_inflight_fd( &mut self, inflight: &VhostUserInflight, ) -> Result<(VhostUserInflight, File)>177 fn get_inflight_fd( 178 &mut self, 179 inflight: &VhostUserInflight, 180 ) -> Result<(VhostUserInflight, File)> { 181 self.as_mut().get_inflight_fd(inflight) 182 } 183 set_inflight_fd(&mut self, inflight: &VhostUserInflight, file: File) -> Result<()>184 fn set_inflight_fd(&mut self, inflight: &VhostUserInflight, file: File) -> Result<()> { 185 self.as_mut().set_inflight_fd(inflight, file) 186 } 187 get_max_mem_slots(&mut self) -> Result<u64>188 fn get_max_mem_slots(&mut self) -> Result<u64> { 189 self.as_mut().get_max_mem_slots() 190 } 191 add_mem_region(&mut self, region: &VhostUserSingleMemoryRegion, fd: File) -> Result<()>192 fn add_mem_region(&mut self, region: &VhostUserSingleMemoryRegion, fd: File) -> Result<()> { 193 self.as_mut().add_mem_region(region, fd) 194 } 195 
remove_mem_region(&mut self, region: &VhostUserSingleMemoryRegion) -> Result<()>196 fn remove_mem_region(&mut self, region: &VhostUserSingleMemoryRegion) -> Result<()> { 197 self.as_mut().remove_mem_region(region) 198 } 199 get_shared_memory_regions(&mut self) -> Result<Vec<VhostSharedMemoryRegion>>200 fn get_shared_memory_regions(&mut self) -> Result<Vec<VhostSharedMemoryRegion>> { 201 self.as_mut().get_shared_memory_regions() 202 } 203 sleep(&mut self) -> Result<()>204 fn sleep(&mut self) -> Result<()> { 205 self.as_mut().sleep() 206 } 207 wake(&mut self) -> Result<()>208 fn wake(&mut self) -> Result<()> { 209 self.as_mut().wake() 210 } 211 snapshot(&mut self) -> Result<Vec<u8>>212 fn snapshot(&mut self) -> Result<Vec<u8>> { 213 self.as_mut().snapshot() 214 } 215 restore(&mut self, data_bytes: &[u8], queue_evts: Vec<File>) -> Result<()>216 fn restore(&mut self, data_bytes: &[u8], queue_evts: Vec<File>) -> Result<()> { 217 self.as_mut().restore(data_bytes, queue_evts) 218 } 219 } 220 221 /// Handles requests from a vhost-user connection by dispatching them to [[Backend]] methods. 222 pub struct BackendServer<S: Backend> { 223 /// Underlying connection for communication. 224 connection: Connection<FrontendReq>, 225 // the vhost-user backend device object 226 backend: S, 227 228 virtio_features: u64, 229 acked_virtio_features: u64, 230 protocol_features: VhostUserProtocolFeatures, 231 acked_protocol_features: u64, 232 233 /// Sending ack for messages without payload. 234 reply_ack_enabled: bool, 235 } 236 237 impl<S: Backend> BackendServer<S> { 238 /// Create a backend server from a connected socket. 
from_stream(socket: SystemStream, backend: S) -> Self239 pub fn from_stream(socket: SystemStream, backend: S) -> Self { 240 Self::new(Connection::from(socket), backend) 241 } 242 } 243 244 impl<S: Backend> AsRef<S> for BackendServer<S> { as_ref(&self) -> &S245 fn as_ref(&self) -> &S { 246 &self.backend 247 } 248 } 249 250 impl<S: Backend> BackendServer<S> { new(connection: Connection<FrontendReq>, backend: S) -> Self251 pub fn new(connection: Connection<FrontendReq>, backend: S) -> Self { 252 BackendServer { 253 connection, 254 backend, 255 virtio_features: 0, 256 acked_virtio_features: 0, 257 protocol_features: VhostUserProtocolFeatures::empty(), 258 acked_protocol_features: 0, 259 reply_ack_enabled: false, 260 } 261 } 262 263 /// Receives and validates a vhost-user message header and optional files. 264 /// 265 /// Since the length of vhost-user messages are different among message types, regular 266 /// vhost-user messages are sent via an underlying communication channel in stream mode. 267 /// (e.g. `SOCK_STREAM` in UNIX) 268 /// So, the logic of receiving and handling a message consists of the following steps: 269 /// 270 /// 1. Receives a message header and optional attached file. 271 /// 2. Validates the message header. 272 /// 3. Check if optional payloads is expected. 273 /// 4. Wait for the optional payloads. 274 /// 5. Receives optional payloads. 275 /// 6. Processes the message. 276 /// 277 /// This method [`BackendServer::recv_header()`] is in charge of the step (1) and (2), 278 /// [`BackendServer::needs_wait_for_payload()`] is (3), and 279 /// [`BackendServer::process_message()`] is (5) and (6). 
We need to have the three method 280 /// separately for multi-platform supports; [`BackendServer::recv_header()`] and 281 /// [`BackendServer::process_message()`] need to be separated because the way of waiting for 282 /// incoming messages differs between Unix and Windows so it's the caller's responsibility to 283 /// wait before [`BackendServer::process_message()`]. 284 /// 285 /// Note that some vhost-user protocol variant such as VVU doesn't assume stream mode. In this 286 /// case, a message header and its body are sent together so the step (4) is skipped. We handle 287 /// this case in [`BackendServer::needs_wait_for_payload()`]. 288 /// 289 /// The following pseudo code describes how a caller should process incoming vhost-user 290 /// messages: 291 /// ```ignore 292 /// loop { 293 /// // block until a message header comes. 294 /// // The actual code differs, depending on platforms. 295 /// connection.wait_readable().unwrap(); 296 /// 297 /// // (1) and (2) 298 /// let (hdr, files) = backend_server.recv_header(); 299 /// 300 /// // (3) 301 /// if backend_server.needs_wait_for_payload(&hdr) { 302 /// // (4) block until a payload comes if needed. 303 /// connection.wait_readable().unwrap(); 304 /// } 305 /// 306 /// // (5) and (6) 307 /// backend_server.process_message(&hdr, &files).unwrap(); 308 /// } 309 /// ``` recv_header(&mut self) -> Result<(VhostUserMsgHeader<FrontendReq>, Vec<File>)>310 pub fn recv_header(&mut self) -> Result<(VhostUserMsgHeader<FrontendReq>, Vec<File>)> { 311 // The underlying communication channel is a Unix domain socket in 312 // stream mode, and recvmsg() is a little tricky here. To successfully 313 // receive attached file descriptors, we need to receive messages and 314 // corresponding attached file descriptors in this way: 315 // . recv messsage header and optional attached file 316 // . validate message header 317 // . recv optional message body and payload according size field in 318 // message header 319 // . 
validate message body and optional payload 320 let (hdr, files) = match self.connection.recv_header() { 321 Ok((hdr, files)) => (hdr, files), 322 Err(Error::Disconnect) => { 323 // If the client closed the connection before sending a header, this should be 324 // handled as a legal exit. 325 return Err(Error::ClientExit); 326 } 327 Err(e) => { 328 return Err(e); 329 } 330 }; 331 332 self.check_attached_files(&hdr, &files)?; 333 334 Ok((hdr, files)) 335 } 336 337 /// Returns whether the caller needs to wait for the incoming message before calling 338 /// [`BackendServer::process_message`]. 339 /// 340 /// See [`BackendServer::recv_header`]'s doc comment for the usage. needs_wait_for_payload(&self, hdr: &VhostUserMsgHeader<FrontendReq>) -> bool341 pub fn needs_wait_for_payload(&self, hdr: &VhostUserMsgHeader<FrontendReq>) -> bool { 342 // Since the vhost-user protocol uses stream mode, we need to wait until an additional 343 // payload is available if exists. 344 hdr.get_size() != 0 345 } 346 347 /// Main entrance to request from the communication channel. 348 /// 349 /// Receive and handle one incoming request message from the frontend. 350 /// See [`BackendServer::recv_header`]'s doc comment for the usage. 351 /// 352 /// # Return: 353 /// * - `Ok(())`: one request was successfully handled. 354 /// * - `Err(ClientExit)`: the frontend closed the connection properly. This isn't an actual 355 /// failure. 356 /// * - `Err(Disconnect)`: the connection was closed unexpectedly. 357 /// * - `Err(InvalidMessage)`: the vmm sent a illegal message. 358 /// * - other errors: failed to handle a request. 
process_message( &mut self, hdr: VhostUserMsgHeader<FrontendReq>, files: Vec<File>, ) -> Result<()>359 pub fn process_message( 360 &mut self, 361 hdr: VhostUserMsgHeader<FrontendReq>, 362 files: Vec<File>, 363 ) -> Result<()> { 364 let buf = self.connection.recv_body_bytes(&hdr)?; 365 let size = buf.len(); 366 367 match hdr.get_code() { 368 Ok(FrontendReq::SET_OWNER) => { 369 self.check_request_size(&hdr, size, 0)?; 370 let res = self.backend.set_owner(); 371 self.send_ack_message(&hdr, res.is_ok())?; 372 res?; 373 } 374 Ok(FrontendReq::RESET_OWNER) => { 375 self.check_request_size(&hdr, size, 0)?; 376 let res = self.backend.reset_owner(); 377 self.send_ack_message(&hdr, res.is_ok())?; 378 res?; 379 } 380 Ok(FrontendReq::GET_FEATURES) => { 381 self.check_request_size(&hdr, size, 0)?; 382 let mut features = self.backend.get_features()?; 383 384 // Don't advertise packed queues even if the device does. We don't handle them 385 // properly yet at the protocol layer. 386 // TODO: b/331466964 - Remove once support is added. 387 features &= !(1 << VIRTIO_F_RING_PACKED); 388 389 let msg = VhostUserU64::new(features); 390 self.send_reply_message(&hdr, &msg)?; 391 self.virtio_features = features; 392 self.update_reply_ack_flag(); 393 } 394 Ok(FrontendReq::SET_FEATURES) => { 395 let mut msg = self.extract_request_body::<VhostUserU64>(&hdr, size, &buf)?; 396 397 // Don't allow packed queues even if the device does. We don't handle them 398 // properly yet at the protocol layer. 399 // TODO: b/331466964 - Remove once support is added. 
400 msg.value &= !(1 << VIRTIO_F_RING_PACKED); 401 402 let res = self.backend.set_features(msg.value); 403 self.acked_virtio_features = msg.value; 404 self.update_reply_ack_flag(); 405 self.send_ack_message(&hdr, res.is_ok())?; 406 res?; 407 } 408 Ok(FrontendReq::SET_MEM_TABLE) => { 409 let res = self.set_mem_table(&hdr, size, &buf, files); 410 self.send_ack_message(&hdr, res.is_ok())?; 411 res?; 412 } 413 Ok(FrontendReq::SET_VRING_NUM) => { 414 let msg = self.extract_request_body::<VhostUserVringState>(&hdr, size, &buf)?; 415 let res = self.backend.set_vring_num(msg.index, msg.num); 416 self.send_ack_message(&hdr, res.is_ok())?; 417 res?; 418 } 419 Ok(FrontendReq::SET_VRING_ADDR) => { 420 let msg = self.extract_request_body::<VhostUserVringAddr>(&hdr, size, &buf)?; 421 let flags = match VhostUserVringAddrFlags::from_bits(msg.flags) { 422 Some(val) => val, 423 None => return Err(Error::InvalidMessage), 424 }; 425 let res = self.backend.set_vring_addr( 426 msg.index, 427 flags, 428 msg.descriptor, 429 msg.used, 430 msg.available, 431 msg.log, 432 ); 433 self.send_ack_message(&hdr, res.is_ok())?; 434 res?; 435 } 436 Ok(FrontendReq::SET_VRING_BASE) => { 437 let msg = self.extract_request_body::<VhostUserVringState>(&hdr, size, &buf)?; 438 let res = self.backend.set_vring_base(msg.index, msg.num); 439 self.send_ack_message(&hdr, res.is_ok())?; 440 res?; 441 } 442 Ok(FrontendReq::GET_VRING_BASE) => { 443 let msg = self.extract_request_body::<VhostUserVringState>(&hdr, size, &buf)?; 444 let reply = self.backend.get_vring_base(msg.index)?; 445 self.send_reply_message(&hdr, &reply)?; 446 } 447 Ok(FrontendReq::SET_VRING_CALL) => { 448 self.check_request_size(&hdr, size, mem::size_of::<VhostUserU64>())?; 449 let (index, file) = self.handle_vring_fd_request(&buf, files)?; 450 let res = self.backend.set_vring_call(index, file); 451 self.send_ack_message(&hdr, res.is_ok())?; 452 res?; 453 } 454 Ok(FrontendReq::SET_VRING_KICK) => { 455 self.check_request_size(&hdr, size, 
mem::size_of::<VhostUserU64>())?; 456 let (index, file) = self.handle_vring_fd_request(&buf, files)?; 457 let res = self.backend.set_vring_kick(index, file); 458 self.send_ack_message(&hdr, res.is_ok())?; 459 res?; 460 } 461 Ok(FrontendReq::SET_VRING_ERR) => { 462 self.check_request_size(&hdr, size, mem::size_of::<VhostUserU64>())?; 463 let (index, file) = self.handle_vring_fd_request(&buf, files)?; 464 let res = self.backend.set_vring_err(index, file); 465 self.send_ack_message(&hdr, res.is_ok())?; 466 res?; 467 } 468 Ok(FrontendReq::GET_PROTOCOL_FEATURES) => { 469 self.check_request_size(&hdr, size, 0)?; 470 let features = self.backend.get_protocol_features()?; 471 let msg = VhostUserU64::new(features.bits()); 472 self.send_reply_message(&hdr, &msg)?; 473 self.protocol_features = features; 474 self.update_reply_ack_flag(); 475 } 476 Ok(FrontendReq::SET_PROTOCOL_FEATURES) => { 477 let msg = self.extract_request_body::<VhostUserU64>(&hdr, size, &buf)?; 478 let res = self.backend.set_protocol_features(msg.value); 479 self.acked_protocol_features = msg.value; 480 self.update_reply_ack_flag(); 481 self.send_ack_message(&hdr, res.is_ok())?; 482 res?; 483 } 484 Ok(FrontendReq::GET_QUEUE_NUM) => { 485 if self.acked_protocol_features & VhostUserProtocolFeatures::MQ.bits() == 0 { 486 return Err(Error::InvalidOperation); 487 } 488 self.check_request_size(&hdr, size, 0)?; 489 let num = self.backend.get_queue_num()?; 490 let msg = VhostUserU64::new(num); 491 self.send_reply_message(&hdr, &msg)?; 492 } 493 Ok(FrontendReq::SET_VRING_ENABLE) => { 494 let msg = self.extract_request_body::<VhostUserVringState>(&hdr, size, &buf)?; 495 if self.acked_virtio_features & 1 << VHOST_USER_F_PROTOCOL_FEATURES == 0 { 496 return Err(Error::InvalidOperation); 497 } 498 let enable = match msg.num { 499 1 => true, 500 0 => false, 501 _ => return Err(Error::InvalidParam), 502 }; 503 504 let res = self.backend.set_vring_enable(msg.index, enable); 505 self.send_ack_message(&hdr, res.is_ok())?; 506 
res?; 507 } 508 Ok(FrontendReq::GET_CONFIG) => { 509 if self.acked_protocol_features & VhostUserProtocolFeatures::CONFIG.bits() == 0 { 510 return Err(Error::InvalidOperation); 511 } 512 self.check_request_size(&hdr, size, hdr.get_size() as usize)?; 513 self.get_config(&hdr, &buf)?; 514 } 515 Ok(FrontendReq::SET_CONFIG) => { 516 if self.acked_protocol_features & VhostUserProtocolFeatures::CONFIG.bits() == 0 { 517 return Err(Error::InvalidOperation); 518 } 519 self.check_request_size(&hdr, size, hdr.get_size() as usize)?; 520 let res = self.set_config(&buf); 521 self.send_ack_message(&hdr, res.is_ok())?; 522 res?; 523 } 524 Ok(FrontendReq::SET_BACKEND_REQ_FD) => { 525 if self.acked_protocol_features & VhostUserProtocolFeatures::BACKEND_REQ.bits() == 0 526 { 527 return Err(Error::InvalidOperation); 528 } 529 self.check_request_size(&hdr, size, hdr.get_size() as usize)?; 530 let res = self.set_backend_req_fd(files); 531 self.send_ack_message(&hdr, res.is_ok())?; 532 res?; 533 } 534 Ok(FrontendReq::GET_INFLIGHT_FD) => { 535 if self.acked_protocol_features & VhostUserProtocolFeatures::INFLIGHT_SHMFD.bits() 536 == 0 537 { 538 return Err(Error::InvalidOperation); 539 } 540 541 let msg = self.extract_request_body::<VhostUserInflight>(&hdr, size, &buf)?; 542 let (inflight, file) = self.backend.get_inflight_fd(&msg)?; 543 let reply_hdr = self.new_reply_header::<VhostUserInflight>(&hdr, 0)?; 544 self.connection.send_message( 545 &reply_hdr, 546 &inflight, 547 Some(&[file.as_raw_descriptor()]), 548 )?; 549 } 550 Ok(FrontendReq::SET_INFLIGHT_FD) => { 551 if self.acked_protocol_features & VhostUserProtocolFeatures::INFLIGHT_SHMFD.bits() 552 == 0 553 { 554 return Err(Error::InvalidOperation); 555 } 556 let file = into_single_file(files).ok_or(Error::IncorrectFds)?; 557 let msg = self.extract_request_body::<VhostUserInflight>(&hdr, size, &buf)?; 558 let res = self.backend.set_inflight_fd(&msg, file); 559 self.send_ack_message(&hdr, res.is_ok())?; 560 res?; 561 } 562 
Ok(FrontendReq::GET_MAX_MEM_SLOTS) => { 563 if self.acked_protocol_features 564 & VhostUserProtocolFeatures::CONFIGURE_MEM_SLOTS.bits() 565 == 0 566 { 567 return Err(Error::InvalidOperation); 568 } 569 self.check_request_size(&hdr, size, 0)?; 570 let num = self.backend.get_max_mem_slots()?; 571 let msg = VhostUserU64::new(num); 572 self.send_reply_message(&hdr, &msg)?; 573 } 574 Ok(FrontendReq::ADD_MEM_REG) => { 575 if self.acked_protocol_features 576 & VhostUserProtocolFeatures::CONFIGURE_MEM_SLOTS.bits() 577 == 0 578 { 579 return Err(Error::InvalidOperation); 580 } 581 let file = into_single_file(files).ok_or(Error::InvalidParam)?; 582 let msg = 583 self.extract_request_body::<VhostUserSingleMemoryRegion>(&hdr, size, &buf)?; 584 let res = self.backend.add_mem_region(&msg, file); 585 self.send_ack_message(&hdr, res.is_ok())?; 586 res?; 587 } 588 Ok(FrontendReq::REM_MEM_REG) => { 589 if self.acked_protocol_features 590 & VhostUserProtocolFeatures::CONFIGURE_MEM_SLOTS.bits() 591 == 0 592 { 593 return Err(Error::InvalidOperation); 594 } 595 596 let msg = 597 self.extract_request_body::<VhostUserSingleMemoryRegion>(&hdr, size, &buf)?; 598 let res = self.backend.remove_mem_region(&msg); 599 self.send_ack_message(&hdr, res.is_ok())?; 600 res?; 601 } 602 Ok(FrontendReq::GET_SHARED_MEMORY_REGIONS) => { 603 let regions = self.backend.get_shared_memory_regions()?; 604 let mut buf = Vec::new(); 605 let msg = VhostUserU64::new(regions.len() as u64); 606 for r in regions { 607 buf.extend_from_slice(r.as_bytes()) 608 } 609 self.send_reply_with_payload(&hdr, &msg, buf.as_slice())?; 610 } 611 Ok(FrontendReq::SLEEP) => { 612 let res = self.backend.sleep(); 613 let msg = VhostUserSuccess::new(res.is_ok()); 614 self.send_reply_message(&hdr, &msg)?; 615 } 616 Ok(FrontendReq::WAKE) => { 617 let res = self.backend.wake(); 618 let msg = VhostUserSuccess::new(res.is_ok()); 619 self.send_reply_message(&hdr, &msg)?; 620 } 621 Ok(FrontendReq::SNAPSHOT) => { 622 let (success_msg, payload) = 
match self.backend.snapshot() { 623 Ok(snapshot_payload) => (VhostUserSuccess::new(true), snapshot_payload), 624 Err(e) => { 625 error!("Failed to snapshot: {}", e); 626 (VhostUserSuccess::new(false), Vec::new()) 627 } 628 }; 629 self.send_reply_with_payload(&hdr, &success_msg, payload.as_slice())?; 630 } 631 Ok(FrontendReq::RESTORE) => { 632 let res = self.backend.restore(buf.as_slice(), files); 633 let msg = VhostUserSuccess::new(res.is_ok()); 634 self.send_reply_message(&hdr, &msg)?; 635 } 636 _ => { 637 return Err(Error::InvalidMessage); 638 } 639 } 640 Ok(()) 641 } 642 new_reply_header<T: Sized>( &self, req: &VhostUserMsgHeader<FrontendReq>, payload_size: usize, ) -> Result<VhostUserMsgHeader<FrontendReq>>643 fn new_reply_header<T: Sized>( 644 &self, 645 req: &VhostUserMsgHeader<FrontendReq>, 646 payload_size: usize, 647 ) -> Result<VhostUserMsgHeader<FrontendReq>> { 648 Ok(VhostUserMsgHeader::new( 649 req.get_code().map_err(|_| Error::InvalidMessage)?, 650 VhostUserHeaderFlag::REPLY.bits(), 651 (mem::size_of::<T>() 652 .checked_add(payload_size) 653 .ok_or(Error::OversizedMsg)?) 654 .try_into() 655 .map_err(Error::InvalidCastToInt)?, 656 )) 657 } 658 659 /// Sends reply back to Vhost frontend in response to a message. 
send_ack_message( &mut self, req: &VhostUserMsgHeader<FrontendReq>, success: bool, ) -> Result<()>660 fn send_ack_message( 661 &mut self, 662 req: &VhostUserMsgHeader<FrontendReq>, 663 success: bool, 664 ) -> Result<()> { 665 if self.reply_ack_enabled && req.is_need_reply() { 666 let hdr: VhostUserMsgHeader<FrontendReq> = 667 self.new_reply_header::<VhostUserU64>(req, 0)?; 668 let val = if success { 0 } else { 1 }; 669 let msg = VhostUserU64::new(val); 670 self.connection.send_message(&hdr, &msg, None)?; 671 } 672 Ok(()) 673 } 674 send_reply_message<T: Sized + AsBytes>( &mut self, req: &VhostUserMsgHeader<FrontendReq>, msg: &T, ) -> Result<()>675 fn send_reply_message<T: Sized + AsBytes>( 676 &mut self, 677 req: &VhostUserMsgHeader<FrontendReq>, 678 msg: &T, 679 ) -> Result<()> { 680 let hdr = self.new_reply_header::<T>(req, 0)?; 681 self.connection.send_message(&hdr, msg, None)?; 682 Ok(()) 683 } 684 send_reply_with_payload<T: Sized + AsBytes>( &mut self, req: &VhostUserMsgHeader<FrontendReq>, msg: &T, payload: &[u8], ) -> Result<()>685 fn send_reply_with_payload<T: Sized + AsBytes>( 686 &mut self, 687 req: &VhostUserMsgHeader<FrontendReq>, 688 msg: &T, 689 payload: &[u8], 690 ) -> Result<()> { 691 let hdr = self.new_reply_header::<T>(req, payload.len())?; 692 self.connection 693 .send_message_with_payload(&hdr, msg, payload, None)?; 694 Ok(()) 695 } 696 set_mem_table( &mut self, hdr: &VhostUserMsgHeader<FrontendReq>, size: usize, buf: &[u8], files: Vec<File>, ) -> Result<()>697 fn set_mem_table( 698 &mut self, 699 hdr: &VhostUserMsgHeader<FrontendReq>, 700 size: usize, 701 buf: &[u8], 702 files: Vec<File>, 703 ) -> Result<()> { 704 self.check_request_size(hdr, size, hdr.get_size() as usize)?; 705 706 let (msg, regions) = 707 Ref::<_, VhostUserMemory>::new_from_prefix(buf).ok_or(Error::InvalidMessage)?; 708 if !msg.is_valid() { 709 return Err(Error::InvalidMessage); 710 } 711 712 // validate number of fds matching number of memory regions 713 if files.len() != 
msg.num_regions as usize { 714 return Err(Error::InvalidMessage); 715 } 716 717 let (regions, excess) = Ref::<_, [VhostUserMemoryRegion]>::new_slice_from_prefix( 718 regions, 719 msg.num_regions as usize, 720 ) 721 .ok_or(Error::InvalidMessage)?; 722 if !excess.is_empty() { 723 return Err(Error::InvalidMessage); 724 } 725 726 // Validate memory regions 727 for region in regions.iter() { 728 if !region.is_valid() { 729 return Err(Error::InvalidMessage); 730 } 731 } 732 733 self.backend.set_mem_table(®ions, files) 734 } 735 get_config(&mut self, hdr: &VhostUserMsgHeader<FrontendReq>, buf: &[u8]) -> Result<()>736 fn get_config(&mut self, hdr: &VhostUserMsgHeader<FrontendReq>, buf: &[u8]) -> Result<()> { 737 let (msg, payload) = 738 Ref::<_, VhostUserConfig>::new_from_prefix(buf).ok_or(Error::InvalidMessage)?; 739 if !msg.is_valid() { 740 return Err(Error::InvalidMessage); 741 } 742 if payload.len() != msg.size as usize { 743 return Err(Error::InvalidMessage); 744 } 745 let flags = match VhostUserConfigFlags::from_bits(msg.flags) { 746 Some(val) => val, 747 None => return Err(Error::InvalidMessage), 748 }; 749 let res = self.backend.get_config(msg.offset, msg.size, flags); 750 751 // The response payload size MUST match the request payload size on success. A zero length 752 // response is used to indicate an error. 
753 match res { 754 Ok(ref buf) if buf.len() == msg.size as usize => { 755 let reply = VhostUserConfig::new(msg.offset, buf.len() as u32, flags); 756 self.send_reply_with_payload(hdr, &reply, buf.as_slice())?; 757 } 758 Ok(_) => { 759 let reply = VhostUserConfig::new(msg.offset, 0, flags); 760 self.send_reply_message(hdr, &reply)?; 761 } 762 Err(_) => { 763 let reply = VhostUserConfig::new(msg.offset, 0, flags); 764 self.send_reply_message(hdr, &reply)?; 765 } 766 } 767 Ok(()) 768 } 769 set_config(&mut self, buf: &[u8]) -> Result<()>770 fn set_config(&mut self, buf: &[u8]) -> Result<()> { 771 let (msg, payload) = 772 Ref::<_, VhostUserConfig>::new_from_prefix(buf).ok_or(Error::InvalidMessage)?; 773 if !msg.is_valid() { 774 return Err(Error::InvalidMessage); 775 } 776 if payload.len() != msg.size as usize { 777 return Err(Error::InvalidMessage); 778 } 779 let flags: VhostUserConfigFlags = match VhostUserConfigFlags::from_bits(msg.flags) { 780 Some(val) => val, 781 None => return Err(Error::InvalidMessage), 782 }; 783 784 self.backend.set_config(msg.offset, payload, flags) 785 } 786 set_backend_req_fd(&mut self, files: Vec<File>) -> Result<()>787 fn set_backend_req_fd(&mut self, files: Vec<File>) -> Result<()> { 788 let file = into_single_file(files).ok_or(Error::InvalidMessage)?; 789 let fd = file.into(); 790 // SAFETY: Safe because the protocol promises the file represents the appropriate file type 791 // for the platform. 792 let stream = unsafe { to_system_stream(fd) }?; 793 self.backend.set_backend_req_fd(Connection::from(stream)); 794 Ok(()) 795 } 796 797 /// Parses an incoming |SET_VRING_KICK| or |SET_VRING_CALL| message into a 798 /// Vring number and an fd. 
fn handle_vring_fd_request(
        &mut self,
        buf: &[u8],
        files: Vec<File>,
    ) -> Result<(u8, Option<File>)> {
        // Layout of the u64 payload: bits 0-7 hold the vring index and bit 8 is the invalid FD
        // flag (VHOST_USER_VRING_NOFD_MASK).
        const VRING_IDX_MASK: u64 = 0xff;
        const VRING_NOFD_MASK: u64 = 0x100;

        let msg = VhostUserU64::read_from_prefix(buf).ok_or(Error::InvalidMessage)?;
        if !msg.is_valid() {
            return Err(Error::InvalidMessage);
        }

        // No other bits are defined. Reject them instead of silently truncating the value to a
        // u8, which would otherwise alias an out-of-range value onto a valid vring index.
        if (msg.value & !(VRING_IDX_MASK | VRING_NOFD_MASK)) != 0 {
            return Err(Error::InvalidMessage);
        }

        // The NOFD bit is set when there is no file descriptor in the ancillary data. This
        // signals that polling will be used instead of waiting for the call. If the bit is
        // clear, the message must carry a file descriptor.
        let has_fd = (msg.value & VRING_NOFD_MASK) == 0;
        let file = into_single_file(files);
        if has_fd != file.is_some() {
            return Err(Error::InvalidMessage);
        }

        // Lossless: the reserved-bit check above guarantees the index fits in bits 0-7.
        Ok(((msg.value & VRING_IDX_MASK) as u8, file))
    }

    /// Validates the framing of a fixed-size request.
    ///
    /// Returns `Error::InvalidMessage` unless all of the following hold: the header-declared
    /// size and the actually received body size `size` both equal `expected`, the header is not
    /// a reply, and the header version is 0x1.
    fn check_request_size(
        &self,
        hdr: &VhostUserMsgHeader<FrontendReq>,
        size: usize,
        expected: usize,
    ) -> Result<()> {
        if hdr.get_size() as usize != expected
            || hdr.is_reply()
            || hdr.get_version() != 0x1
            || size != expected
        {
            return Err(Error::InvalidMessage);
        }
        Ok(())
    }

    /// Rejects attached files on requests that are not allowed to carry them.
    ///
    /// The listed request codes may arrive with files; their handlers are responsible for
    /// checking the file count. Any other known request must arrive with no files, and an
    /// unrecognized request code is always rejected with `Error::InvalidMessage`.
    fn check_attached_files(
        &self,
        hdr: &VhostUserMsgHeader<FrontendReq>,
        files: &[File],
    ) -> Result<()> {
        match hdr.get_code() {
            Ok(FrontendReq::SET_MEM_TABLE)
            | Ok(FrontendReq::SET_VRING_CALL)
            | Ok(FrontendReq::SET_VRING_KICK)
            | Ok(FrontendReq::SET_VRING_ERR)
            | Ok(FrontendReq::SET_LOG_BASE)
            | Ok(FrontendReq::SET_LOG_FD)
            | Ok(FrontendReq::SET_BACKEND_REQ_FD)
            | Ok(FrontendReq::SET_INFLIGHT_FD)
            | Ok(FrontendReq::RESTORE)
            | Ok(FrontendReq::ADD_MEM_REG) => Ok(()),
            Err(_) => Err(Error::InvalidMessage),
            _ if !files.is_empty() => Err(Error::InvalidMessage),
            _ => Ok(()),
        }
    }

    /// Reads a fixed-size request body of type `T` from the front of `buf`.
    ///
    /// Framing is checked with `check_request_size` against `size_of::<T>()`. Returns
    /// `Error::InvalidMessage` if that check fails, if `buf` is too short, or if the decoded
    /// value fails `T::is_valid`.
    fn extract_request_body<T: Sized + FromBytes + VhostUserMsgValidator>(
        &self,
        hdr: &VhostUserMsgHeader<FrontendReq>,
        size: usize,
        buf: &[u8],
    ) -> Result<T> {
        self.check_request_size(hdr, size, mem::size_of::<T>())?;
        T::read_from_prefix(buf)
            .filter(T::is_valid)
            .ok_or(Error::InvalidMessage)
    }

    /// Recomputes whether REPLY_ACK is in effect.
    ///
    /// It is enabled only when all three of the following hold:
    /// - `VHOST_USER_F_PROTOCOL_FEATURES` is set in `virtio_features`;
    /// - `protocol_features` contains REPLY_ACK;
    /// - the REPLY_ACK bit is set in `acked_protocol_features`.
    fn update_reply_ack_flag(&mut self) {
        let pflag = VhostUserProtocolFeatures::REPLY_ACK;
        self.reply_ack_enabled = (self.virtio_features & (1 << VHOST_USER_F_PROTOCOL_FEATURES))
            != 0
            && self.protocol_features.contains(pflag)
            && (self.acked_protocol_features & pflag.bits()) != 0;
    }
}

impl<S: Backend> AsRawDescriptor for BackendServer<S> {
    fn as_raw_descriptor(&self) -> RawDescriptor {
        // TODO(b/221882601): figure out if this used for polling.
        self.connection.as_raw_descriptor()
    }
}

#[cfg(test)]
mod tests {
    use base::INVALID_DESCRIPTOR;

    use super::*;
    use crate::test_backend::TestBackend;
    use crate::Connection;
    use crate::SystemStream;

    #[test]
    fn test_backend_server_new() {
        let (p1, _p2) = SystemStream::pair().unwrap();
        let connection = Connection::from(p1);
        let backend = TestBackend::new();
        let handler = BackendServer::new(connection, backend);

        assert!(handler.as_raw_descriptor() != INVALID_DESCRIPTOR);
    }
}