1 // Copyright (C) 2019 Alibaba Cloud Computing. All rights reserved. 2 // SPDX-License-Identifier: Apache-2.0 3 4 use base::{AsRawDescriptor, RawDescriptor}; 5 use std::fs::File; 6 use std::mem; 7 use std::slice; 8 use std::sync::{Arc, Mutex}; 9 10 use data_model::DataInit; 11 12 use super::connection::{Endpoint, EndpointExt}; 13 use super::message::*; 14 use super::{take_single_file, Error, Result}; 15 use crate::{MasterReqEndpoint, SystemStream}; 16 17 #[derive(PartialEq, Eq, Debug)] 18 /// Vhost-user protocol variants used for the communication. 19 pub enum Protocol { 20 /// Use the regular vhost-user protocol. 21 Regular, 22 /// Use the virtio-vhost-user protocol, which is proxied through virtqueues. 23 /// The protocol is mostly same as the vhost-user protocol but no file transfer is allowed. 24 Virtio, 25 } 26 27 /// Services provided to the master by the slave with interior mutability. 28 /// 29 /// The [VhostUserSlaveReqHandler] trait defines the services provided to the master by the slave. 30 /// And the [VhostUserSlaveReqHandlerMut] trait is a helper mirroring [VhostUserSlaveReqHandler], 31 /// but without interior mutability. 32 /// The vhost-user specification defines a master communication channel, by which masters could 33 /// request services from slaves. The [VhostUserSlaveReqHandler] trait defines services provided by 34 /// slaves, and it's used both on the master side and slave side. 35 /// 36 /// - on the master side, a stub forwarder implementing [VhostUserSlaveReqHandler] will proxy 37 /// service requests to slaves. 38 /// - on the slave side, the [SlaveReqHandler] will forward service requests to a handler 39 /// implementing [VhostUserSlaveReqHandler]. 40 /// 41 /// The [VhostUserSlaveReqHandler] trait is design with interior mutability to improve performance 42 /// for multi-threading. 
43 /// 44 /// [VhostUserSlaveReqHandler]: trait.VhostUserSlaveReqHandler.html 45 /// [VhostUserSlaveReqHandlerMut]: trait.VhostUserSlaveReqHandlerMut.html 46 /// [SlaveReqHandler]: struct.SlaveReqHandler.html 47 #[allow(missing_docs)] 48 pub trait VhostUserSlaveReqHandler { 49 /// Returns the type of vhost-user protocol that the handler support. protocol(&self) -> Protocol50 fn protocol(&self) -> Protocol; 51 set_owner(&self) -> Result<()>52 fn set_owner(&self) -> Result<()>; reset_owner(&self) -> Result<()>53 fn reset_owner(&self) -> Result<()>; get_features(&self) -> Result<u64>54 fn get_features(&self) -> Result<u64>; set_features(&self, features: u64) -> Result<()>55 fn set_features(&self, features: u64) -> Result<()>; set_mem_table(&self, ctx: &[VhostUserMemoryRegion], files: Vec<File>) -> Result<()>56 fn set_mem_table(&self, ctx: &[VhostUserMemoryRegion], files: Vec<File>) -> Result<()>; set_vring_num(&self, index: u32, num: u32) -> Result<()>57 fn set_vring_num(&self, index: u32, num: u32) -> Result<()>; set_vring_addr( &self, index: u32, flags: VhostUserVringAddrFlags, descriptor: u64, used: u64, available: u64, log: u64, ) -> Result<()>58 fn set_vring_addr( 59 &self, 60 index: u32, 61 flags: VhostUserVringAddrFlags, 62 descriptor: u64, 63 used: u64, 64 available: u64, 65 log: u64, 66 ) -> Result<()>; set_vring_base(&self, index: u32, base: u32) -> Result<()>67 fn set_vring_base(&self, index: u32, base: u32) -> Result<()>; get_vring_base(&self, index: u32) -> Result<VhostUserVringState>68 fn get_vring_base(&self, index: u32) -> Result<VhostUserVringState>; set_vring_kick(&self, index: u8, fd: Option<File>) -> Result<()>69 fn set_vring_kick(&self, index: u8, fd: Option<File>) -> Result<()>; set_vring_call(&self, index: u8, fd: Option<File>) -> Result<()>70 fn set_vring_call(&self, index: u8, fd: Option<File>) -> Result<()>; set_vring_err(&self, index: u8, fd: Option<File>) -> Result<()>71 fn set_vring_err(&self, index: u8, fd: Option<File>) -> Result<()>; 72 
get_protocol_features(&self) -> Result<VhostUserProtocolFeatures>73 fn get_protocol_features(&self) -> Result<VhostUserProtocolFeatures>; set_protocol_features(&self, features: u64) -> Result<()>74 fn set_protocol_features(&self, features: u64) -> Result<()>; get_queue_num(&self) -> Result<u64>75 fn get_queue_num(&self) -> Result<u64>; set_vring_enable(&self, index: u32, enable: bool) -> Result<()>76 fn set_vring_enable(&self, index: u32, enable: bool) -> Result<()>; get_config(&self, offset: u32, size: u32, flags: VhostUserConfigFlags) -> Result<Vec<u8>>77 fn get_config(&self, offset: u32, size: u32, flags: VhostUserConfigFlags) -> Result<Vec<u8>>; set_config(&self, offset: u32, buf: &[u8], flags: VhostUserConfigFlags) -> Result<()>78 fn set_config(&self, offset: u32, buf: &[u8], flags: VhostUserConfigFlags) -> Result<()>; set_slave_req_fd(&self, _vu_req: File)79 fn set_slave_req_fd(&self, _vu_req: File) {} get_inflight_fd(&self, inflight: &VhostUserInflight) -> Result<(VhostUserInflight, File)>80 fn get_inflight_fd(&self, inflight: &VhostUserInflight) -> Result<(VhostUserInflight, File)>; set_inflight_fd(&self, inflight: &VhostUserInflight, file: File) -> Result<()>81 fn set_inflight_fd(&self, inflight: &VhostUserInflight, file: File) -> Result<()>; get_max_mem_slots(&self) -> Result<u64>82 fn get_max_mem_slots(&self) -> Result<u64>; add_mem_region(&self, region: &VhostUserSingleMemoryRegion, fd: File) -> Result<()>83 fn add_mem_region(&self, region: &VhostUserSingleMemoryRegion, fd: File) -> Result<()>; remove_mem_region(&self, region: &VhostUserSingleMemoryRegion) -> Result<()>84 fn remove_mem_region(&self, region: &VhostUserSingleMemoryRegion) -> Result<()>; 85 } 86 87 /// Services provided to the master by the slave without interior mutability. 88 /// 89 /// This is a helper trait mirroring the [VhostUserSlaveReqHandler] trait. 
90 #[allow(missing_docs)] 91 pub trait VhostUserSlaveReqHandlerMut { 92 /// Returns the type of vhost-user protocol that the handler support. protocol(&self) -> Protocol93 fn protocol(&self) -> Protocol; 94 set_owner(&mut self) -> Result<()>95 fn set_owner(&mut self) -> Result<()>; reset_owner(&mut self) -> Result<()>96 fn reset_owner(&mut self) -> Result<()>; get_features(&mut self) -> Result<u64>97 fn get_features(&mut self) -> Result<u64>; set_features(&mut self, features: u64) -> Result<()>98 fn set_features(&mut self, features: u64) -> Result<()>; set_mem_table(&mut self, ctx: &[VhostUserMemoryRegion], files: Vec<File>) -> Result<()>99 fn set_mem_table(&mut self, ctx: &[VhostUserMemoryRegion], files: Vec<File>) -> Result<()>; set_vring_num(&mut self, index: u32, num: u32) -> Result<()>100 fn set_vring_num(&mut self, index: u32, num: u32) -> Result<()>; set_vring_addr( &mut self, index: u32, flags: VhostUserVringAddrFlags, descriptor: u64, used: u64, available: u64, log: u64, ) -> Result<()>101 fn set_vring_addr( 102 &mut self, 103 index: u32, 104 flags: VhostUserVringAddrFlags, 105 descriptor: u64, 106 used: u64, 107 available: u64, 108 log: u64, 109 ) -> Result<()>; set_vring_base(&mut self, index: u32, base: u32) -> Result<()>110 fn set_vring_base(&mut self, index: u32, base: u32) -> Result<()>; get_vring_base(&mut self, index: u32) -> Result<VhostUserVringState>111 fn get_vring_base(&mut self, index: u32) -> Result<VhostUserVringState>; set_vring_kick(&mut self, index: u8, fd: Option<File>) -> Result<()>112 fn set_vring_kick(&mut self, index: u8, fd: Option<File>) -> Result<()>; set_vring_call(&mut self, index: u8, fd: Option<File>) -> Result<()>113 fn set_vring_call(&mut self, index: u8, fd: Option<File>) -> Result<()>; set_vring_err(&mut self, index: u8, fd: Option<File>) -> Result<()>114 fn set_vring_err(&mut self, index: u8, fd: Option<File>) -> Result<()>; 115 get_protocol_features(&mut self) -> Result<VhostUserProtocolFeatures>116 fn 
get_protocol_features(&mut self) -> Result<VhostUserProtocolFeatures>; set_protocol_features(&mut self, features: u64) -> Result<()>117 fn set_protocol_features(&mut self, features: u64) -> Result<()>; get_queue_num(&mut self) -> Result<u64>118 fn get_queue_num(&mut self) -> Result<u64>; set_vring_enable(&mut self, index: u32, enable: bool) -> Result<()>119 fn set_vring_enable(&mut self, index: u32, enable: bool) -> Result<()>; get_config( &mut self, offset: u32, size: u32, flags: VhostUserConfigFlags, ) -> Result<Vec<u8>>120 fn get_config( 121 &mut self, 122 offset: u32, 123 size: u32, 124 flags: VhostUserConfigFlags, 125 ) -> Result<Vec<u8>>; set_config(&mut self, offset: u32, buf: &[u8], flags: VhostUserConfigFlags) -> Result<()>126 fn set_config(&mut self, offset: u32, buf: &[u8], flags: VhostUserConfigFlags) -> Result<()>; set_slave_req_fd(&mut self, _vu_req: File)127 fn set_slave_req_fd(&mut self, _vu_req: File) {} get_inflight_fd( &mut self, inflight: &VhostUserInflight, ) -> Result<(VhostUserInflight, File)>128 fn get_inflight_fd( 129 &mut self, 130 inflight: &VhostUserInflight, 131 ) -> Result<(VhostUserInflight, File)>; set_inflight_fd(&mut self, inflight: &VhostUserInflight, file: File) -> Result<()>132 fn set_inflight_fd(&mut self, inflight: &VhostUserInflight, file: File) -> Result<()>; get_max_mem_slots(&mut self) -> Result<u64>133 fn get_max_mem_slots(&mut self) -> Result<u64>; add_mem_region(&mut self, region: &VhostUserSingleMemoryRegion, fd: File) -> Result<()>134 fn add_mem_region(&mut self, region: &VhostUserSingleMemoryRegion, fd: File) -> Result<()>; remove_mem_region(&mut self, region: &VhostUserSingleMemoryRegion) -> Result<()>135 fn remove_mem_region(&mut self, region: &VhostUserSingleMemoryRegion) -> Result<()>; 136 } 137 138 impl<T: VhostUserSlaveReqHandlerMut> VhostUserSlaveReqHandler for Mutex<T> { protocol(&self) -> Protocol139 fn protocol(&self) -> Protocol { 140 self.lock().unwrap().protocol() 141 } 142 set_owner(&self) -> 
Result<()>143 fn set_owner(&self) -> Result<()> { 144 self.lock().unwrap().set_owner() 145 } 146 reset_owner(&self) -> Result<()>147 fn reset_owner(&self) -> Result<()> { 148 self.lock().unwrap().reset_owner() 149 } 150 get_features(&self) -> Result<u64>151 fn get_features(&self) -> Result<u64> { 152 self.lock().unwrap().get_features() 153 } 154 set_features(&self, features: u64) -> Result<()>155 fn set_features(&self, features: u64) -> Result<()> { 156 self.lock().unwrap().set_features(features) 157 } 158 set_mem_table(&self, ctx: &[VhostUserMemoryRegion], files: Vec<File>) -> Result<()>159 fn set_mem_table(&self, ctx: &[VhostUserMemoryRegion], files: Vec<File>) -> Result<()> { 160 self.lock().unwrap().set_mem_table(ctx, files) 161 } 162 set_vring_num(&self, index: u32, num: u32) -> Result<()>163 fn set_vring_num(&self, index: u32, num: u32) -> Result<()> { 164 self.lock().unwrap().set_vring_num(index, num) 165 } 166 set_vring_addr( &self, index: u32, flags: VhostUserVringAddrFlags, descriptor: u64, used: u64, available: u64, log: u64, ) -> Result<()>167 fn set_vring_addr( 168 &self, 169 index: u32, 170 flags: VhostUserVringAddrFlags, 171 descriptor: u64, 172 used: u64, 173 available: u64, 174 log: u64, 175 ) -> Result<()> { 176 self.lock() 177 .unwrap() 178 .set_vring_addr(index, flags, descriptor, used, available, log) 179 } 180 set_vring_base(&self, index: u32, base: u32) -> Result<()>181 fn set_vring_base(&self, index: u32, base: u32) -> Result<()> { 182 self.lock().unwrap().set_vring_base(index, base) 183 } 184 get_vring_base(&self, index: u32) -> Result<VhostUserVringState>185 fn get_vring_base(&self, index: u32) -> Result<VhostUserVringState> { 186 self.lock().unwrap().get_vring_base(index) 187 } 188 set_vring_kick(&self, index: u8, fd: Option<File>) -> Result<()>189 fn set_vring_kick(&self, index: u8, fd: Option<File>) -> Result<()> { 190 self.lock().unwrap().set_vring_kick(index, fd) 191 } 192 set_vring_call(&self, index: u8, fd: Option<File>) -> 
Result<()>193 fn set_vring_call(&self, index: u8, fd: Option<File>) -> Result<()> { 194 self.lock().unwrap().set_vring_call(index, fd) 195 } 196 set_vring_err(&self, index: u8, fd: Option<File>) -> Result<()>197 fn set_vring_err(&self, index: u8, fd: Option<File>) -> Result<()> { 198 self.lock().unwrap().set_vring_err(index, fd) 199 } 200 get_protocol_features(&self) -> Result<VhostUserProtocolFeatures>201 fn get_protocol_features(&self) -> Result<VhostUserProtocolFeatures> { 202 self.lock().unwrap().get_protocol_features() 203 } 204 set_protocol_features(&self, features: u64) -> Result<()>205 fn set_protocol_features(&self, features: u64) -> Result<()> { 206 self.lock().unwrap().set_protocol_features(features) 207 } 208 get_queue_num(&self) -> Result<u64>209 fn get_queue_num(&self) -> Result<u64> { 210 self.lock().unwrap().get_queue_num() 211 } 212 set_vring_enable(&self, index: u32, enable: bool) -> Result<()>213 fn set_vring_enable(&self, index: u32, enable: bool) -> Result<()> { 214 self.lock().unwrap().set_vring_enable(index, enable) 215 } 216 get_config(&self, offset: u32, size: u32, flags: VhostUserConfigFlags) -> Result<Vec<u8>>217 fn get_config(&self, offset: u32, size: u32, flags: VhostUserConfigFlags) -> Result<Vec<u8>> { 218 self.lock().unwrap().get_config(offset, size, flags) 219 } 220 set_config(&self, offset: u32, buf: &[u8], flags: VhostUserConfigFlags) -> Result<()>221 fn set_config(&self, offset: u32, buf: &[u8], flags: VhostUserConfigFlags) -> Result<()> { 222 self.lock().unwrap().set_config(offset, buf, flags) 223 } 224 set_slave_req_fd(&self, vu_req: File)225 fn set_slave_req_fd(&self, vu_req: File) { 226 self.lock().unwrap().set_slave_req_fd(vu_req) 227 } 228 get_inflight_fd(&self, inflight: &VhostUserInflight) -> Result<(VhostUserInflight, File)>229 fn get_inflight_fd(&self, inflight: &VhostUserInflight) -> Result<(VhostUserInflight, File)> { 230 self.lock().unwrap().get_inflight_fd(inflight) 231 } 232 set_inflight_fd(&self, inflight: 
&VhostUserInflight, file: File) -> Result<()>233 fn set_inflight_fd(&self, inflight: &VhostUserInflight, file: File) -> Result<()> { 234 self.lock().unwrap().set_inflight_fd(inflight, file) 235 } 236 get_max_mem_slots(&self) -> Result<u64>237 fn get_max_mem_slots(&self) -> Result<u64> { 238 self.lock().unwrap().get_max_mem_slots() 239 } 240 add_mem_region(&self, region: &VhostUserSingleMemoryRegion, fd: File) -> Result<()>241 fn add_mem_region(&self, region: &VhostUserSingleMemoryRegion, fd: File) -> Result<()> { 242 self.lock().unwrap().add_mem_region(region, fd) 243 } 244 remove_mem_region(&self, region: &VhostUserSingleMemoryRegion) -> Result<()>245 fn remove_mem_region(&self, region: &VhostUserSingleMemoryRegion) -> Result<()> { 246 self.lock().unwrap().remove_mem_region(region) 247 } 248 } 249 250 /// Abstracts |Endpoint| related operations for vhost-user slave implementations. 251 pub struct SlaveReqHelper<E: Endpoint<MasterReq>> { 252 /// Underlying endpoint for communication. 253 endpoint: E, 254 255 /// Protocol used for the communication. 256 protocol: Protocol, 257 258 /// Sending ack for messages without payload. 259 reply_ack_enabled: bool, 260 } 261 262 impl<E: Endpoint<MasterReq>> SlaveReqHelper<E> { 263 /// Creates a new |SlaveReqHelper| instance with an |Endpoint| underneath it. 
new(endpoint: E, protocol: Protocol) -> Self264 pub fn new(endpoint: E, protocol: Protocol) -> Self { 265 SlaveReqHelper { 266 endpoint, 267 protocol, 268 reply_ack_enabled: false, 269 } 270 } 271 new_reply_header<T: Sized>( &self, req: &VhostUserMsgHeader<MasterReq>, payload_size: usize, ) -> Result<VhostUserMsgHeader<MasterReq>>272 fn new_reply_header<T: Sized>( 273 &self, 274 req: &VhostUserMsgHeader<MasterReq>, 275 payload_size: usize, 276 ) -> Result<VhostUserMsgHeader<MasterReq>> { 277 if mem::size_of::<T>() > MAX_MSG_SIZE 278 || payload_size > MAX_MSG_SIZE 279 || mem::size_of::<T>() + payload_size > MAX_MSG_SIZE 280 { 281 return Err(Error::InvalidParam); 282 } 283 284 Ok(VhostUserMsgHeader::new( 285 req.get_code(), 286 VhostUserHeaderFlag::REPLY.bits(), 287 (mem::size_of::<T>() + payload_size) as u32, 288 )) 289 } 290 291 /// Sends reply back to Vhost Master in response to a message. send_ack_message( &mut self, req: &VhostUserMsgHeader<MasterReq>, success: bool, ) -> Result<()>292 pub fn send_ack_message( 293 &mut self, 294 req: &VhostUserMsgHeader<MasterReq>, 295 success: bool, 296 ) -> Result<()> { 297 if self.reply_ack_enabled && req.is_need_reply() { 298 let hdr = self.new_reply_header::<VhostUserU64>(req, 0)?; 299 let val = if success { 0 } else { 1 }; 300 let msg = VhostUserU64::new(val); 301 self.endpoint.send_message(&hdr, &msg, None)?; 302 } 303 Ok(()) 304 } 305 send_reply_message<T: Sized + DataInit>( &mut self, req: &VhostUserMsgHeader<MasterReq>, msg: &T, ) -> Result<()>306 fn send_reply_message<T: Sized + DataInit>( 307 &mut self, 308 req: &VhostUserMsgHeader<MasterReq>, 309 msg: &T, 310 ) -> Result<()> { 311 let hdr = self.new_reply_header::<T>(req, 0)?; 312 self.endpoint.send_message(&hdr, msg, None)?; 313 Ok(()) 314 } 315 send_reply_with_payload<T: Sized + DataInit>( &mut self, req: &VhostUserMsgHeader<MasterReq>, msg: &T, payload: &[u8], ) -> Result<()>316 fn send_reply_with_payload<T: Sized + DataInit>( 317 &mut self, 318 req: 
&VhostUserMsgHeader<MasterReq>, 319 msg: &T, 320 payload: &[u8], 321 ) -> Result<()> { 322 let hdr = self.new_reply_header::<T>(req, payload.len())?; 323 self.endpoint 324 .send_message_with_payload(&hdr, msg, payload, None)?; 325 Ok(()) 326 } 327 328 /// Parses an incoming |SET_VRING_KICK| or |SET_VRING_CALL| message into a 329 /// Vring number and an fd. handle_vring_fd_request( &mut self, buf: &[u8], files: Option<Vec<File>>, ) -> Result<(u8, Option<File>)>330 pub fn handle_vring_fd_request( 331 &mut self, 332 buf: &[u8], 333 files: Option<Vec<File>>, 334 ) -> Result<(u8, Option<File>)> { 335 if buf.len() > MAX_MSG_SIZE || buf.len() < mem::size_of::<VhostUserU64>() { 336 return Err(Error::InvalidMessage); 337 } 338 let msg = unsafe { std::ptr::read_unaligned(buf.as_ptr() as *const VhostUserU64) }; 339 if !msg.is_valid() { 340 return Err(Error::InvalidMessage); 341 } 342 343 // Virtio-vhost-user protocol doesn't send FDs. 344 if self.protocol == Protocol::Virtio { 345 return Ok((msg.value as u8, None)); 346 } 347 348 // Bits (0-7) of the payload contain the vring index. Bit 8 is the 349 // invalid FD flag (VHOST_USER_VRING_NOFD_MASK). 350 // This bit is set when there is no file descriptor 351 // in the ancillary data. This signals that polling will be used 352 // instead of waiting for the call. 353 // If Bit 8 is unset, the data must contain a file descriptor. 
354 let has_fd = (msg.value & 0x100u64) == 0; 355 356 let file = take_single_file(files); 357 358 if has_fd && file.is_none() || !has_fd && file.is_some() { 359 return Err(Error::InvalidMessage); 360 } 361 362 Ok((msg.value as u8, file)) 363 } 364 } 365 366 impl<E: Endpoint<MasterReq>> AsRef<E> for SlaveReqHelper<E> { as_ref(&self) -> &E367 fn as_ref(&self) -> &E { 368 &self.endpoint 369 } 370 } 371 372 impl<E: Endpoint<MasterReq>> AsMut<E> for SlaveReqHelper<E> { as_mut(&mut self) -> &mut E373 fn as_mut(&mut self) -> &mut E { 374 &mut self.endpoint 375 } 376 } 377 378 impl<E: Endpoint<MasterReq> + AsRawDescriptor> AsRawDescriptor for SlaveReqHelper<E> { as_raw_descriptor(&self) -> RawDescriptor379 fn as_raw_descriptor(&self) -> RawDescriptor { 380 self.endpoint.as_raw_descriptor() 381 } 382 } 383 384 /// Server to handle service requests from masters from the master communication channel. 385 /// 386 /// The [SlaveReqHandler] acts as a server on the slave side, to handle service requests from 387 /// masters on the master communication channel. It's actually a proxy invoking the registered 388 /// handler implementing [VhostUserSlaveReqHandler] to do the real work. 389 /// 390 /// The lifetime of the SlaveReqHandler object should be the same as the underline Unix Domain 391 /// Socket, so it gets simpler to recover from disconnect. 
392 /// 393 /// [VhostUserSlaveReqHandler]: trait.VhostUserSlaveReqHandler.html 394 /// [SlaveReqHandler]: struct.SlaveReqHandler.html 395 pub struct SlaveReqHandler<S: VhostUserSlaveReqHandler, E: Endpoint<MasterReq>> { 396 slave_req_helper: SlaveReqHelper<E>, 397 // the vhost-user backend device object 398 backend: Arc<S>, 399 400 virtio_features: u64, 401 acked_virtio_features: u64, 402 protocol_features: VhostUserProtocolFeatures, 403 acked_protocol_features: u64, 404 405 // whether the endpoint has encountered any failure 406 error: Option<i32>, 407 } 408 409 impl<S: VhostUserSlaveReqHandler> SlaveReqHandler<S, MasterReqEndpoint> { 410 /// Create a vhost-user slave endpoint from a connected socket. from_stream(socket: SystemStream, backend: Arc<S>) -> Self411 pub fn from_stream(socket: SystemStream, backend: Arc<S>) -> Self { 412 Self::new(MasterReqEndpoint::from(socket), backend) 413 } 414 } 415 416 impl<S: VhostUserSlaveReqHandler, E: Endpoint<MasterReq>> SlaveReqHandler<S, E> { 417 /// Create a vhost-user slave endpoint. new(endpoint: E, backend: Arc<S>) -> Self418 pub(super) fn new(endpoint: E, backend: Arc<S>) -> Self { 419 SlaveReqHandler { 420 slave_req_helper: SlaveReqHelper::new(endpoint, backend.protocol()), 421 backend, 422 virtio_features: 0, 423 acked_virtio_features: 0, 424 protocol_features: VhostUserProtocolFeatures::empty(), 425 acked_protocol_features: 0, 426 error: None, 427 } 428 } 429 430 /// Create a new vhost-user slave endpoint. 431 /// 432 /// # Arguments 433 /// * - `path` - path of Unix domain socket listener to connect to 434 /// * - `backend` - handler for requests from the master to the slave connect(path: &str, backend: Arc<S>) -> Result<Self>435 pub fn connect(path: &str, backend: Arc<S>) -> Result<Self> { 436 Ok(Self::new(Endpoint::<MasterReq>::connect(path)?, backend)) 437 } 438 439 /// Mark endpoint as failed with specified error code. 
set_failed(&mut self, error: i32)440 pub fn set_failed(&mut self, error: i32) { 441 self.error = Some(error); 442 } 443 444 /// Main entrance to request from the communication channel. 445 /// 446 /// Receive and handle one incoming request message from the vmm. The caller needs to: 447 /// - serialize calls to this function 448 /// - decide what to do when error happens 449 /// - optional recover from failure 450 /// 451 /// # Return: 452 /// * - `Ok(())`: one request was successfully handled. 453 /// * - `Err(ClientExit)`: the vmm closed the connection properly. This isn't an actual failure. 454 /// * - `Err(Disconnect)`: the connection was closed unexpectedly. 455 /// * - `Err(InvalidMessage)`: the vmm sent a illegal message. 456 /// * - other errors: failed to handle a request. handle_request(&mut self) -> Result<()>457 pub fn handle_request(&mut self) -> Result<()> { 458 // Return error if the endpoint is already in failed state. 459 self.check_state()?; 460 461 // The underlying communication channel is a Unix domain socket in 462 // stream mode, and recvmsg() is a little tricky here. To successfully 463 // receive attached file descriptors, we need to receive messages and 464 // corresponding attached file descriptors in this way: 465 // . recv messsage header and optional attached file 466 // . validate message header 467 // . recv optional message body and payload according size field in 468 // message header 469 // . validate message body and optional payload 470 let (hdr, files) = match self.slave_req_helper.endpoint.recv_header() { 471 Ok((hdr, files)) => (hdr, files), 472 Err(Error::Disconnect) => { 473 // If the client closed the connection before sending a header, this should be 474 // handled as a legal exit. 
475 return Err(Error::ClientExit); 476 } 477 Err(e) => { 478 return Err(e); 479 } 480 }; 481 482 self.check_attached_files(&hdr, &files)?; 483 484 let buf = match hdr.get_size() { 485 0 => vec![0u8; 0], 486 len => { 487 let rbuf = self.slave_req_helper.endpoint.recv_data(len as usize)?; 488 if rbuf.len() != len as usize { 489 return Err(Error::InvalidMessage); 490 } 491 rbuf 492 } 493 }; 494 let size = buf.len(); 495 496 match hdr.get_code() { 497 MasterReq::SET_OWNER => { 498 self.check_request_size(&hdr, size, 0)?; 499 let res = self.backend.set_owner(); 500 self.slave_req_helper.send_ack_message(&hdr, res.is_ok())?; 501 res?; 502 } 503 MasterReq::RESET_OWNER => { 504 self.check_request_size(&hdr, size, 0)?; 505 let res = self.backend.reset_owner(); 506 self.slave_req_helper.send_ack_message(&hdr, res.is_ok())?; 507 res?; 508 } 509 MasterReq::GET_FEATURES => { 510 self.check_request_size(&hdr, size, 0)?; 511 let features = self.backend.get_features()?; 512 let msg = VhostUserU64::new(features); 513 self.slave_req_helper.send_reply_message(&hdr, &msg)?; 514 self.virtio_features = features; 515 self.update_reply_ack_flag(); 516 } 517 MasterReq::SET_FEATURES => { 518 let msg = self.extract_request_body::<VhostUserU64>(&hdr, size, &buf)?; 519 let res = self.backend.set_features(msg.value); 520 self.acked_virtio_features = msg.value; 521 self.update_reply_ack_flag(); 522 self.slave_req_helper.send_ack_message(&hdr, res.is_ok())?; 523 res?; 524 } 525 MasterReq::SET_MEM_TABLE => { 526 let res = self.set_mem_table(&hdr, size, &buf, files); 527 self.slave_req_helper.send_ack_message(&hdr, res.is_ok())?; 528 res?; 529 } 530 MasterReq::SET_VRING_NUM => { 531 let msg = self.extract_request_body::<VhostUserVringState>(&hdr, size, &buf)?; 532 let res = self.backend.set_vring_num(msg.index, msg.num); 533 self.slave_req_helper.send_ack_message(&hdr, res.is_ok())?; 534 res?; 535 } 536 MasterReq::SET_VRING_ADDR => { 537 let msg = 
self.extract_request_body::<VhostUserVringAddr>(&hdr, size, &buf)?; 538 let flags = match VhostUserVringAddrFlags::from_bits(msg.flags) { 539 Some(val) => val, 540 None => return Err(Error::InvalidMessage), 541 }; 542 let res = self.backend.set_vring_addr( 543 msg.index, 544 flags, 545 msg.descriptor, 546 msg.used, 547 msg.available, 548 msg.log, 549 ); 550 self.slave_req_helper.send_ack_message(&hdr, res.is_ok())?; 551 res?; 552 } 553 MasterReq::SET_VRING_BASE => { 554 let msg = self.extract_request_body::<VhostUserVringState>(&hdr, size, &buf)?; 555 let res = self.backend.set_vring_base(msg.index, msg.num); 556 self.slave_req_helper.send_ack_message(&hdr, res.is_ok())?; 557 res?; 558 } 559 MasterReq::GET_VRING_BASE => { 560 let msg = self.extract_request_body::<VhostUserVringState>(&hdr, size, &buf)?; 561 let reply = self.backend.get_vring_base(msg.index)?; 562 self.slave_req_helper.send_reply_message(&hdr, &reply)?; 563 } 564 MasterReq::SET_VRING_CALL => { 565 self.check_request_size(&hdr, size, mem::size_of::<VhostUserU64>())?; 566 let (index, file) = self.handle_vring_fd_request(&buf, files)?; 567 let res = self.backend.set_vring_call(index, file); 568 self.slave_req_helper.send_ack_message(&hdr, res.is_ok())?; 569 res?; 570 } 571 MasterReq::SET_VRING_KICK => { 572 self.check_request_size(&hdr, size, mem::size_of::<VhostUserU64>())?; 573 let (index, file) = self.handle_vring_fd_request(&buf, files)?; 574 let res = self.backend.set_vring_kick(index, file); 575 self.slave_req_helper.send_ack_message(&hdr, res.is_ok())?; 576 res?; 577 } 578 MasterReq::SET_VRING_ERR => { 579 self.check_request_size(&hdr, size, mem::size_of::<VhostUserU64>())?; 580 let (index, file) = self.handle_vring_fd_request(&buf, files)?; 581 let res = self.backend.set_vring_err(index, file); 582 self.slave_req_helper.send_ack_message(&hdr, res.is_ok())?; 583 res?; 584 } 585 MasterReq::GET_PROTOCOL_FEATURES => { 586 self.check_request_size(&hdr, size, 0)?; 587 let features = 
self.backend.get_protocol_features()?; 588 let msg = VhostUserU64::new(features.bits()); 589 self.slave_req_helper.send_reply_message(&hdr, &msg)?; 590 self.protocol_features = features; 591 self.update_reply_ack_flag(); 592 } 593 MasterReq::SET_PROTOCOL_FEATURES => { 594 let msg = self.extract_request_body::<VhostUserU64>(&hdr, size, &buf)?; 595 let res = self.backend.set_protocol_features(msg.value); 596 self.acked_protocol_features = msg.value; 597 self.update_reply_ack_flag(); 598 self.slave_req_helper.send_ack_message(&hdr, res.is_ok())?; 599 res?; 600 } 601 MasterReq::GET_QUEUE_NUM => { 602 if self.acked_protocol_features & VhostUserProtocolFeatures::MQ.bits() == 0 { 603 return Err(Error::InvalidOperation); 604 } 605 self.check_request_size(&hdr, size, 0)?; 606 let num = self.backend.get_queue_num()?; 607 let msg = VhostUserU64::new(num); 608 self.slave_req_helper.send_reply_message(&hdr, &msg)?; 609 } 610 MasterReq::SET_VRING_ENABLE => { 611 let msg = self.extract_request_body::<VhostUserVringState>(&hdr, size, &buf)?; 612 if self.acked_virtio_features & VhostUserVirtioFeatures::PROTOCOL_FEATURES.bits() 613 == 0 614 { 615 return Err(Error::InvalidOperation); 616 } 617 let enable = match msg.num { 618 1 => true, 619 0 => false, 620 _ => return Err(Error::InvalidParam), 621 }; 622 623 let res = self.backend.set_vring_enable(msg.index, enable); 624 self.slave_req_helper.send_ack_message(&hdr, res.is_ok())?; 625 res?; 626 } 627 MasterReq::GET_CONFIG => { 628 if self.acked_protocol_features & VhostUserProtocolFeatures::CONFIG.bits() == 0 { 629 return Err(Error::InvalidOperation); 630 } 631 self.check_request_size(&hdr, size, hdr.get_size() as usize)?; 632 self.get_config(&hdr, &buf)?; 633 } 634 MasterReq::SET_CONFIG => { 635 if self.acked_protocol_features & VhostUserProtocolFeatures::CONFIG.bits() == 0 { 636 return Err(Error::InvalidOperation); 637 } 638 self.check_request_size(&hdr, size, hdr.get_size() as usize)?; 639 let res = self.set_config(size, &buf); 640 
// Tail of the preceding request arm: acknowledge the outcome, then propagate any error.
                self.slave_req_helper.send_ack_message(&hdr, res.is_ok())?;
                res?;
            }
            MasterReq::SET_SLAVE_REQ_FD => {
                if self.acked_protocol_features & VhostUserProtocolFeatures::SLAVE_REQ.bits() == 0 {
                    return Err(Error::InvalidOperation);
                }
                self.check_request_size(&hdr, size, hdr.get_size() as usize)?;
                let res = self.set_slave_req_fd(files);
                self.slave_req_helper.send_ack_message(&hdr, res.is_ok())?;
                res?;
            }
            MasterReq::GET_INFLIGHT_FD => {
                if self.acked_protocol_features & VhostUserProtocolFeatures::INFLIGHT_SHMFD.bits()
                    == 0
                {
                    return Err(Error::InvalidOperation);
                }

                let msg = self.extract_request_body::<VhostUserInflight>(&hdr, size, &buf)?;
                let (inflight, file) = self.backend.get_inflight_fd(&msg)?;
                let reply_hdr = self
                    .slave_req_helper
                    .new_reply_header::<VhostUserInflight>(&hdr, 0)?;
                self.slave_req_helper.endpoint.send_message(
                    &reply_hdr,
                    &inflight,
                    Some(&[file.as_raw_descriptor()]),
                )?;
            }
            MasterReq::SET_INFLIGHT_FD => {
                if self.acked_protocol_features & VhostUserProtocolFeatures::INFLIGHT_SHMFD.bits()
                    == 0
                {
                    return Err(Error::InvalidOperation);
                }
                let file = take_single_file(files).ok_or(Error::IncorrectFds)?;
                let msg = self.extract_request_body::<VhostUserInflight>(&hdr, size, &buf)?;
                let res = self.backend.set_inflight_fd(&msg, file);
                self.slave_req_helper.send_ack_message(&hdr, res.is_ok())?;
                res?;
            }
            MasterReq::GET_MAX_MEM_SLOTS => {
                if self.acked_protocol_features
                    & VhostUserProtocolFeatures::CONFIGURE_MEM_SLOTS.bits()
                    == 0
                {
                    return Err(Error::InvalidOperation);
                }
                self.check_request_size(&hdr, size, 0)?;
                let num = self.backend.get_max_mem_slots()?;
                let msg = VhostUserU64::new(num);
                self.slave_req_helper.send_reply_message(&hdr, &msg)?;
            }
            MasterReq::ADD_MEM_REG => {
                if self.acked_protocol_features
                    & VhostUserProtocolFeatures::CONFIGURE_MEM_SLOTS.bits()
                    == 0
                {
                    return Err(Error::InvalidOperation);
                }
                // Exactly one file (the region's backing fd) must accompany this request.
                let mut files = files.ok_or(Error::InvalidParam)?;
                if files.len() != 1 {
                    return Err(Error::InvalidParam);
                }
                let msg =
                    self.extract_request_body::<VhostUserSingleMemoryRegion>(&hdr, size, &buf)?;
                let res = self.backend.add_mem_region(&msg, files.swap_remove(0));
                self.slave_req_helper.send_ack_message(&hdr, res.is_ok())?;
                res?;
            }
            MasterReq::REM_MEM_REG => {
                if self.acked_protocol_features
                    & VhostUserProtocolFeatures::CONFIGURE_MEM_SLOTS.bits()
                    == 0
                {
                    return Err(Error::InvalidOperation);
                }

                let msg =
                    self.extract_request_body::<VhostUserSingleMemoryRegion>(&hdr, size, &buf)?;
                let res = self.backend.remove_mem_region(&msg);
                self.slave_req_helper.send_ack_message(&hdr, res.is_ok())?;
                res?;
            }
            _ => {
                return Err(Error::InvalidMessage);
            }
        }
        Ok(())
    }

    /// Handles `SET_MEM_TABLE`: validates the memory table in `buf` and the attached files,
    /// then forwards the regions to the backend.
    ///
    /// `buf` must begin with a `VhostUserMemory` header followed by `num_regions`
    /// `VhostUserMemoryRegion` entries, and `size` must cover exactly that much data. With the
    /// regular protocol, one file per region must be attached. With the virtio protocol, no
    /// files are forwarded.
    ///
    /// # Errors
    ///
    /// Returns `Error::InvalidMessage` on any size, header, region, or file-count mismatch.
    /// Otherwise returns whatever the backend's `set_mem_table` returns.
    fn set_mem_table(
        &mut self,
        hdr: &VhostUserMsgHeader<MasterReq>,
        size: usize,
        buf: &[u8],
        files: Option<Vec<File>>,
    ) -> Result<()> {
        self.check_request_size(hdr, size, hdr.get_size() as usize)?;

        // Check that the message size is consistent, and that `buf` really holds `size`
        // bytes, because every read below trusts `size`.
        let hdrsize = mem::size_of::<VhostUserMemory>();
        if size < hdrsize || buf.len() < size {
            return Err(Error::InvalidMessage);
        }
        // SAFETY: `buf` holds at least `hdrsize` bytes (checked above), and `read_unaligned`
        // places no alignment requirement on the source pointer.
        let msg = unsafe { std::ptr::read_unaligned(buf.as_ptr() as *const VhostUserMemory) };
        if !msg.is_valid() {
            return Err(Error::InvalidMessage);
        }
        let region_size = mem::size_of::<VhostUserMemoryRegion>();
        // Checked arithmetic so a huge `num_regions` cannot wrap around on narrow `usize`.
        let expected_size = (msg.num_regions as usize)
            .checked_mul(region_size)
            .and_then(|len| len.checked_add(hdrsize));
        if expected_size != Some(size) {
            return Err(Error::InvalidMessage);
        }

        let files = match self.slave_req_helper.protocol {
            Protocol::Regular => {
                // Validate that the number of fds matches the number of memory regions.
                let files = files.ok_or(Error::InvalidMessage)?;
                if files.len() != msg.num_regions as usize {
                    return Err(Error::InvalidMessage);
                }
                files
            }
            Protocol::Virtio => vec![],
        };

        // Copy the region table out of `buf`. `buf` is a byte slice with no alignment
        // guarantee, so building a `&[VhostUserMemoryRegion]` directly over it would be
        // undefined behavior.
        let regions: Vec<VhostUserMemoryRegion> = buf[hdrsize..size]
            .chunks_exact(region_size)
            .map(|chunk| {
                let ptr = chunk.as_ptr() as *const VhostUserMemoryRegion;
                // SAFETY: `chunks_exact` yields slices of exactly `region_size` bytes, and
                // `read_unaligned` places no alignment requirement on `ptr`.
                unsafe { std::ptr::read_unaligned(ptr) }
            })
            .collect();
        if regions.iter().any(|region| !region.is_valid()) {
            return Err(Error::InvalidMessage);
        }

        self.backend.set_mem_table(&regions, files)
    }

    /// Handles `GET_CONFIG`: validates the `VhostUserConfig` request in `buf`, queries the
    /// backend, and sends the reply.
    ///
    /// The reply payload must be exactly as long as the requested size. A backend error, or
    /// data of any other length, is reported with a zero-length payload.
    fn get_config(&mut self, hdr: &VhostUserMsgHeader<MasterReq>, buf: &[u8]) -> Result<()> {
        let payload_offset = mem::size_of::<VhostUserConfig>();
        if buf.len() > MAX_MSG_SIZE || buf.len() < payload_offset {
            return Err(Error::InvalidMessage);
        }
        // SAFETY: `buf` holds at least `payload_offset` bytes (checked above), and
        // `read_unaligned` places no alignment requirement on the source pointer.
        let msg = unsafe { std::ptr::read_unaligned(buf.as_ptr() as *const VhostUserConfig) };
        if !msg.is_valid() {
            return Err(Error::InvalidMessage);
        }
        if buf.len() - payload_offset != msg.size as usize {
            return Err(Error::InvalidMessage);
        }
        let flags = VhostUserConfigFlags::from_bits(msg.flags).ok_or(Error::InvalidMessage)?;
        let res = self.backend.get_config(msg.offset, msg.size, flags);

        match res {
            Ok(ref data) if data.len() == msg.size as usize => {
                let reply = VhostUserConfig::new(msg.offset, data.len() as u32, flags);
                self.slave_req_helper
                    .send_reply_with_payload(hdr, &reply, data.as_slice())?;
            }
            // Wrong-sized data and backend errors are both signalled by an empty payload.
            _ => {
                let reply = VhostUserConfig::new(msg.offset, 0, flags);
                self.slave_req_helper.send_reply_message(hdr, &reply)?;
            }
        }
        Ok(())
    }

    /// Handles `SET_CONFIG`: validates the `VhostUserConfig` header at the start of `buf` and
    /// forwards the payload that follows it to the backend.
    ///
    /// `size` is the message body length. It must cover the header plus exactly `msg.size`
    /// payload bytes, and it must not exceed `buf.len()`.
    fn set_config(&mut self, size: usize, buf: &[u8]) -> Result<()> {
        let payload_offset = mem::size_of::<VhostUserConfig>();
        if size > MAX_MSG_SIZE || size < payload_offset || buf.len() < size {
            return Err(Error::InvalidMessage);
        }
        // SAFETY: `buf` holds at least `payload_offset` bytes (checked above), and
        // `read_unaligned` places no alignment requirement on the source pointer.
        let msg = unsafe { std::ptr::read_unaligned(buf.as_ptr() as *const VhostUserConfig) };
        if !msg.is_valid() {
            return Err(Error::InvalidMessage);
        }
        if size - payload_offset != msg.size as usize {
            return Err(Error::InvalidMessage);
        }
        let flags = VhostUserConfigFlags::from_bits(msg.flags).ok_or(Error::InvalidMessage)?;

        // Only the bytes after the header are configuration data. The header itself has
        // already been decoded into `msg`.
        self.backend
            .set_config(msg.offset, &buf[payload_offset..size], flags)
    }

    /// Handles `SET_SLAVE_REQ_FD` by handing the attached file to the backend.
    ///
    /// # Errors
    ///
    /// Returns `Error::InvalidOperation` on Windows, where this request is not supported yet.
    /// Returns `Error::InvalidMessage` when `take_single_file` yields no file.
    fn set_slave_req_fd(&mut self, files: Option<Vec<File>>) -> Result<()> {
        if cfg!(windows) {
            // The request comes from the peer, so reject it instead of panicking the process.
            return Err(Error::InvalidOperation);
        }
        let file = take_single_file(files).ok_or(Error::InvalidMessage)?;
        self.backend.set_slave_req_fd(file);
        Ok(())
    }

    /// Decodes a vring fd request by delegating to the helper's `handle_vring_fd_request`.
    fn handle_vring_fd_request(
        &mut self,
        buf: &[u8],
        files: Option<Vec<File>>,
    ) -> Result<(u8, Option<File>)> {
        self.slave_req_helper.handle_vring_fd_request(buf, files)
    }

    /// Returns `Error::SocketBroken` built from the recorded OS error code, if one was stored.
    fn check_state(&self) -> Result<()> {
        match self.error {
            Some(e) => Err(Error::SocketBroken(std::io::Error::from_raw_os_error(e))),
            None => Ok(()),
        }
    }

    /// Checks that a request is well-formed for its expected body length.
    ///
    /// Rejects the request with `Error::InvalidMessage` unless all of the following hold:
    /// - the header's declared size equals `expected`;
    /// - the received `size` equals `expected`;
    /// - the header is not a reply;
    /// - the header's version is 1.
    fn check_request_size(
        &self,
        hdr: &VhostUserMsgHeader<MasterReq>,
        size: usize,
        expected: usize,
    ) -> Result<()> {
        if hdr.get_size() as usize != expected
            || hdr.is_reply()
            || hdr.get_version() != 0x1
            || size != expected
        {
            return Err(Error::InvalidMessage);
        }
        Ok(())
    }

    /// Rejects attached files on any request that is not in the file-carrying list below.
    fn check_attached_files(
        &self,
        hdr: &VhostUserMsgHeader<MasterReq>,
        files: &Option<Vec<File>>,
    ) -> Result<()> {
        match hdr.get_code() {
            MasterReq::SET_MEM_TABLE
            | MasterReq::SET_VRING_CALL
            | MasterReq::SET_VRING_KICK
            | MasterReq::SET_VRING_ERR
            | MasterReq::SET_LOG_BASE
            | MasterReq::SET_LOG_FD
            | MasterReq::SET_SLAVE_REQ_FD
            | MasterReq::SET_INFLIGHT_FD
            | MasterReq::ADD_MEM_REG => Ok(()),
            _ if files.is_some() => Err(Error::InvalidMessage),
            _ => Ok(()),
        }
    }

    /// Reads a fixed-size request body of type `T` from `buf` and validates it.
    ///
    /// # Errors
    ///
    /// Returns `Error::InvalidMessage` in any of these cases:
    /// - the sizes do not match `size_of::<T>()`;
    /// - `buf` is shorter than `size`;
    /// - the decoded body fails `is_valid`.
    fn extract_request_body<T: Sized + DataInit + VhostUserMsgValidator>(
        &self,
        hdr: &VhostUserMsgHeader<MasterReq>,
        size: usize,
        buf: &[u8],
    ) -> Result<T> {
        self.check_request_size(hdr, size, mem::size_of::<T>())?;
        if buf.len() < size {
            return Err(Error::InvalidMessage);
        }
        // SAFETY: `buf` holds at least `size_of::<T>()` bytes (`size` equals that and is
        // within `buf`, both checked above), and `read_unaligned` has no alignment requirement.
        // The `T: DataInit` bound marks `T` as constructible from raw bytes.
        let msg = unsafe { std::ptr::read_unaligned(buf.as_ptr() as *const T) };
        if !msg.is_valid() {
            return Err(Error::InvalidMessage);
        }
        Ok(msg)
    }

    /// Recomputes whether REPLY_ACK handling is enabled. All three conditions must hold:
    /// - the PROTOCOL_FEATURES virtio feature is set in `virtio_features`;
    /// - REPLY_ACK is in `protocol_features`;
    /// - REPLY_ACK is set in `acked_protocol_features`.
    fn update_reply_ack_flag(&mut self) {
        let vflag = VhostUserVirtioFeatures::PROTOCOL_FEATURES.bits();
        let pflag = VhostUserProtocolFeatures::REPLY_ACK;
        self.slave_req_helper.reply_ack_enabled = (self.virtio_features & vflag) != 0
            && self.protocol_features.contains(pflag)
            && (self.acked_protocol_features & pflag.bits()) != 0;
    }
}

impl<S: VhostUserSlaveReqHandler, E: AsRawDescriptor + Endpoint<MasterReq>> AsRawDescriptor
    for SlaveReqHandler<S, E>
{
    fn as_raw_descriptor(&self) -> RawDescriptor {
        // TODO(b/221882601): figure out if this used for polling.
        self.slave_req_helper.endpoint.as_raw_descriptor()
    }
}

#[cfg(test)]
mod tests {
    use base::INVALID_DESCRIPTOR;

    use super::*;
    use crate::{dummy_slave::DummySlaveReqHandler, MasterReqEndpoint, SystemStream};

    #[test]
    fn test_slave_req_handler_new() {
        let (p1, _p2) = SystemStream::pair().unwrap();
        let endpoint = MasterReqEndpoint::from(p1);
        let backend = Arc::new(Mutex::new(DummySlaveReqHandler::new()));
        let mut handler = SlaveReqHandler::new(endpoint, backend);

        handler.check_state().unwrap();
        handler.set_failed(libc::EAGAIN);
        handler.check_state().unwrap_err();
        assert!(handler.as_raw_descriptor() != INVALID_DESCRIPTOR);
    }
}