1 // Copyright (C) 2019 Alibaba Cloud Computing. All rights reserved. 2 // SPDX-License-Identifier: Apache-2.0 3 4 use std::fs::File; 5 use std::mem; 6 use std::slice; 7 use std::sync::Mutex; 8 9 use base::AsRawDescriptor; 10 use base::RawDescriptor; 11 use data_model::DataInit; 12 use zerocopy::AsBytes; 13 14 use crate::connection::Endpoint; 15 use crate::connection::EndpointExt; 16 use crate::message::*; 17 use crate::take_single_file; 18 use crate::Error; 19 use crate::MasterReqEndpoint; 20 use crate::Result; 21 use crate::SystemStream; 22 23 #[derive(PartialEq, Eq, Debug)] 24 /// Vhost-user protocol variants used for the communication. 25 pub enum Protocol { 26 /// Use the regular vhost-user protocol. 27 Regular, 28 /// Use the virtio-vhost-user protocol, which is proxied through virtqueues. 29 /// The protocol is mostly same as the vhost-user protocol but no file transfer is allowed. 30 Virtio, 31 } 32 33 impl Protocol { 34 /// Returns whether the protocol assumes that messages are sent in stream mode like Unix's SOCK_STREAM. 35 /// 36 /// In stream mode, the receivers cannot know the size of the entire message in advance so a 37 /// message header with the body size and the message body will be sent separately. See 38 /// [`SlaveReqHandler::recv_header()`]'s doc comment for more details. is_stream_mode(&self) -> bool39 fn is_stream_mode(&self) -> bool { 40 match self { 41 Protocol::Regular => true, 42 // VVU proxy sends a message header and its payload at once. 43 Protocol::Virtio => false, 44 } 45 } 46 } 47 48 /// Services provided to the master by the slave with interior mutability. 49 /// 50 /// The [VhostUserSlaveReqHandler] trait defines the services provided to the master by the slave. 51 /// And the [VhostUserSlaveReqHandlerMut] trait is a helper mirroring [VhostUserSlaveReqHandler], 52 /// but without interior mutability. 
/// The vhost-user specification defines a master communication channel, by which masters could
/// request services from slaves. The [VhostUserSlaveReqHandler] trait defines services provided by
/// slaves, and it's used both on the master side and slave side.
///
/// - on the master side, a stub forwarder implementing [VhostUserSlaveReqHandler] will proxy
///   service requests to slaves.
/// - on the slave side, the [SlaveReqHandler] will forward service requests to a handler
///   implementing [VhostUserSlaveReqHandler].
///
/// The [VhostUserSlaveReqHandler] trait is designed with interior mutability (every method takes
/// `&self`) to improve performance for multi-threading.
///
/// [VhostUserSlaveReqHandler]: trait.VhostUserSlaveReqHandler.html
/// [VhostUserSlaveReqHandlerMut]: trait.VhostUserSlaveReqHandlerMut.html
/// [SlaveReqHandler]: struct.SlaveReqHandler.html
#[allow(missing_docs)]
pub trait VhostUserSlaveReqHandler {
    /// Returns the type of vhost-user protocol that the handler supports.
    fn protocol(&self) -> Protocol;

    // Ownership and virtio feature negotiation; `SlaveReqHandler::process_message` dispatches
    // SET_OWNER, RESET_OWNER, GET_FEATURES and SET_FEATURES requests to these methods.
    fn set_owner(&self) -> Result<()>;
    fn reset_owner(&self) -> Result<()>;
    fn get_features(&self) -> Result<u64>;
    fn set_features(&self, features: u64) -> Result<()>;
    fn set_mem_table(&self, ctx: &[VhostUserMemoryRegion], files: Vec<File>) -> Result<()>;
    // Per-vring setup. SET_VRING_NUM is dispatched to `set_vring_num`.
    fn set_vring_num(&self, index: u32, num: u32) -> Result<()>;
    fn set_vring_addr(
        &self,
        index: u32,
        flags: VhostUserVringAddrFlags,
        descriptor: u64,
        used: u64,
        available: u64,
        log: u64,
    ) -> Result<()>;
    fn set_vring_base(&self, index: u32, base: u32) -> Result<()>;
    fn get_vring_base(&self, index: u32) -> Result<VhostUserVringState>;
    // The `(u8, Option<File>)` arguments presumably come from
    // `SlaveReqHelper::handle_vring_fd_request`, which yields `None` when no fd was attached
    // (NOFD bit set, or the virtio-vhost-user protocol) -- confirm against `process_message`.
    fn set_vring_kick(&self, index: u8, fd: Option<File>) -> Result<()>;
    fn set_vring_call(&self, index: u8, fd: Option<File>) -> Result<()>;
    fn set_vring_err(&self, index: u8, fd: Option<File>) -> Result<()>;

    fn get_protocol_features(&self) -> Result<VhostUserProtocolFeatures>;
    fn set_protocol_features(&self, features: u64) -> Result<()>;
    fn get_queue_num(&self) -> Result<u64>;
    fn set_vring_enable(&self, index: u32, enable: bool) -> Result<()>;
    fn get_config(&self, offset: u32, size: u32, flags: VhostUserConfigFlags) -> Result<Vec<u8>>;
    fn set_config(&self, offset: u32, buf: &[u8], flags: VhostUserConfigFlags) -> Result<()>;
    /// The default implementation ignores the endpoint.
    fn set_slave_req_fd(&self, _vu_req: Box<dyn Endpoint<SlaveReq>>) {}
    fn get_inflight_fd(&self, inflight: &VhostUserInflight) -> Result<(VhostUserInflight, File)>;
    fn set_inflight_fd(&self, inflight: &VhostUserInflight, file: File) -> Result<()>;
    fn get_max_mem_slots(&self) -> Result<u64>;
    fn add_mem_region(&self, region: &VhostUserSingleMemoryRegion, fd: File) -> Result<()>;
    fn remove_mem_region(&self, region: &VhostUserSingleMemoryRegion) -> Result<()>;
    fn get_shared_memory_regions(&self) -> Result<Vec<VhostSharedMemoryRegion>>;
}

/// Services provided to the master by the slave without interior mutability.
///
/// This is a helper trait mirroring the [VhostUserSlaveReqHandler] trait; wrap an implementor
/// in a `Mutex` to obtain a [VhostUserSlaveReqHandler].
#[allow(missing_docs)]
pub trait VhostUserSlaveReqHandlerMut {
    /// Returns the type of vhost-user protocol that the handler supports.
fn protocol(&self) -> Protocol;

// The remaining methods mirror `VhostUserSlaveReqHandler` one-to-one, but take `&mut self`
// so implementors need no interior mutability.
fn set_owner(&mut self) -> Result<()>;
fn reset_owner(&mut self) -> Result<()>;
fn get_features(&mut self) -> Result<u64>;
fn set_features(&mut self, features: u64) -> Result<()>;
fn set_mem_table(&mut self, ctx: &[VhostUserMemoryRegion], files: Vec<File>) -> Result<()>;
fn set_vring_num(&mut self, index: u32, num: u32) -> Result<()>;
fn set_vring_addr(
    &mut self,
    index: u32,
    flags: VhostUserVringAddrFlags,
    descriptor: u64,
    used: u64,
    available: u64,
    log: u64,
) -> Result<()>;
fn set_vring_base(&mut self, index: u32, base: u32) -> Result<()>;
fn get_vring_base(&mut self, index: u32) -> Result<VhostUserVringState>;
fn set_vring_kick(&mut self, index: u8, fd: Option<File>) -> Result<()>;
fn set_vring_call(&mut self, index: u8, fd: Option<File>) -> Result<()>;
fn set_vring_err(&mut self, index: u8, fd: Option<File>) -> Result<()>;

fn get_protocol_features(&mut self) -> Result<VhostUserProtocolFeatures>;
fn set_protocol_features(&mut self, features: u64) -> Result<()>;
fn get_queue_num(&mut self) -> Result<u64>;
fn set_vring_enable(&mut self, index: u32, enable: bool) -> Result<()>;
fn get_config(
    &mut self,
    offset: u32,
    size: u32,
    flags: VhostUserConfigFlags,
) -> Result<Vec<u8>>;
fn set_config(&mut self, offset: u32, buf: &[u8], flags: VhostUserConfigFlags) -> Result<()>;
/// The default implementation ignores the endpoint.
fn set_slave_req_fd(&mut self, _vu_req: Box<dyn Endpoint<SlaveReq>>) {}
fn get_inflight_fd(
    &mut self,
    inflight: &VhostUserInflight,
) -> Result<(VhostUserInflight, File)>;
fn set_inflight_fd(&mut self, inflight: &VhostUserInflight, file: File) -> Result<()>;
fn get_max_mem_slots(&mut self) -> Result<u64>;
fn add_mem_region(&mut self, region: &VhostUserSingleMemoryRegion, fd: File) -> Result<()>;
fn remove_mem_region(&mut self, region: &VhostUserSingleMemoryRegion) -> Result<()>;
fn get_shared_memory_regions(&mut self) -> Result<Vec<VhostSharedMemoryRegion>>;
}

/// Adapts a [VhostUserSlaveReqHandlerMut] into a [VhostUserSlaveReqHandler].
///
/// Every method locks the mutex and forwards the call to the wrapped handler while holding the
/// guard, so requests are serialized. A poisoned mutex panics (`lock().unwrap()`). Note that
/// even `protocol()` takes the lock.
impl<T: VhostUserSlaveReqHandlerMut> VhostUserSlaveReqHandler for Mutex<T> {
    fn protocol(&self) -> Protocol {
        self.lock().unwrap().protocol()
    }

    fn set_owner(&self) -> Result<()> {
        self.lock().unwrap().set_owner()
    }

    fn reset_owner(&self) -> Result<()> {
        self.lock().unwrap().reset_owner()
    }

    fn get_features(&self) -> Result<u64> {
        self.lock().unwrap().get_features()
    }

    fn set_features(&self, features: u64) -> Result<()> {
        self.lock().unwrap().set_features(features)
    }

    fn set_mem_table(&self, ctx: &[VhostUserMemoryRegion], files: Vec<File>) -> Result<()> {
        self.lock().unwrap().set_mem_table(ctx, files)
    }

    fn set_vring_num(&self, index: u32, num: u32) -> Result<()> {
        self.lock().unwrap().set_vring_num(index, num)
    }

    fn set_vring_addr(
        &self,
        index: u32,
        flags: VhostUserVringAddrFlags,
        descriptor: u64,
        used: u64,
        available: u64,
        log: u64,
    ) -> Result<()> {
        self.lock()
            .unwrap()
            .set_vring_addr(index, flags, descriptor, used, available, log)
    }

    fn set_vring_base(&self, index: u32, base: u32) -> Result<()> {
        self.lock().unwrap().set_vring_base(index, base)
    }

    fn get_vring_base(&self, index: u32) -> Result<VhostUserVringState> {
        self.lock().unwrap().get_vring_base(index)
    }

    fn set_vring_kick(&self, index: u8, fd: Option<File>) -> Result<()> {
        self.lock().unwrap().set_vring_kick(index, fd)
    }

fn set_vring_call(&self, index: u8, fd: Option<File>) -> Result<()> {
    self.lock().unwrap().set_vring_call(index, fd)
}

fn set_vring_err(&self, index: u8, fd: Option<File>) -> Result<()> {
    self.lock().unwrap().set_vring_err(index, fd)
}

fn get_protocol_features(&self) -> Result<VhostUserProtocolFeatures> {
    self.lock().unwrap().get_protocol_features()
}

fn set_protocol_features(&self, features: u64) -> Result<()> {
    self.lock().unwrap().set_protocol_features(features)
}

fn get_queue_num(&self) -> Result<u64> {
    self.lock().unwrap().get_queue_num()
}

fn set_vring_enable(&self, index: u32, enable: bool) -> Result<()> {
    self.lock().unwrap().set_vring_enable(index, enable)
}

fn get_config(&self, offset: u32, size: u32, flags: VhostUserConfigFlags) -> Result<Vec<u8>> {
    self.lock().unwrap().get_config(offset, size, flags)
}

fn set_config(&self, offset: u32, buf: &[u8], flags: VhostUserConfigFlags) -> Result<()> {
    self.lock().unwrap().set_config(offset, buf, flags)
}

fn set_slave_req_fd(&self, vu_req: Box<dyn Endpoint<SlaveReq>>) {
    self.lock().unwrap().set_slave_req_fd(vu_req)
}

fn get_inflight_fd(&self, inflight: &VhostUserInflight) -> Result<(VhostUserInflight, File)> {
    self.lock().unwrap().get_inflight_fd(inflight)
}

fn set_inflight_fd(&self, inflight: &VhostUserInflight, file: File) -> Result<()> {
    self.lock().unwrap().set_inflight_fd(inflight, file)
}

fn get_max_mem_slots(&self) -> Result<u64> {
    self.lock().unwrap().get_max_mem_slots()
}

fn add_mem_region(&self, region: &VhostUserSingleMemoryRegion, fd: File) -> Result<()> {
    self.lock().unwrap().add_mem_region(region, fd)
}

fn remove_mem_region(&self, region: &VhostUserSingleMemoryRegion) -> Result<()> {
    self.lock().unwrap().remove_mem_region(region)
}

fn get_shared_memory_regions(&self) -> Result<Vec<VhostSharedMemoryRegion>> {
    self.lock().unwrap().get_shared_memory_regions()
}
}

/// Forwards every request to the handler returned by `AsRef::as_ref`.
///
/// This lets any type that can lend a `&dyn VhostUserSlaveReqHandler` act as a handler. For
/// instance, std's `Box<dyn VhostUserSlaveReqHandler>` and `Arc<dyn VhostUserSlaveReqHandler>`
/// both implement the required `AsRef`.
impl<T> VhostUserSlaveReqHandler for T
where
    T: AsRef<dyn VhostUserSlaveReqHandler>,
{
    fn protocol(&self) -> Protocol {
        self.as_ref().protocol()
    }

    fn set_owner(&self) -> Result<()> {
        self.as_ref().set_owner()
    }

    fn reset_owner(&self) -> Result<()> {
        self.as_ref().reset_owner()
    }

    fn get_features(&self) -> Result<u64> {
        self.as_ref().get_features()
    }

    fn set_features(&self, features: u64) -> Result<()> {
        self.as_ref().set_features(features)
    }

    fn set_mem_table(&self, ctx: &[VhostUserMemoryRegion], files: Vec<File>) -> Result<()> {
        self.as_ref().set_mem_table(ctx, files)
    }

    fn set_vring_num(&self, index: u32, num: u32) -> Result<()> {
        self.as_ref().set_vring_num(index, num)
    }

    fn set_vring_addr(
        &self,
        index: u32,
        flags: VhostUserVringAddrFlags,
        descriptor: u64,
        used: u64,
        available: u64,
        log: u64,
    ) -> Result<()> {
        self.as_ref()
            .set_vring_addr(index, flags, descriptor, used, available, log)
    }

    fn set_vring_base(&self, index: u32, base: u32) -> Result<()> {
        self.as_ref().set_vring_base(index, base)
    }

    fn get_vring_base(&self, index: u32) -> Result<VhostUserVringState> {
        self.as_ref().get_vring_base(index)
    }

    fn set_vring_kick(&self, index: u8, fd: Option<File>) -> Result<()> {
        self.as_ref().set_vring_kick(index, fd)
    }

    fn set_vring_call(&self, index: u8, fd: Option<File>) -> Result<()> {
        self.as_ref().set_vring_call(index, fd)
    }

    fn set_vring_err(&self, index: u8, fd: Option<File>) -> Result<()> {
        self.as_ref().set_vring_err(index, fd)
    }

    fn get_protocol_features(&self) -> Result<VhostUserProtocolFeatures> {
        self.as_ref().get_protocol_features()
    }

    fn set_protocol_features(&self, features: u64) -> Result<()> {
        self.as_ref().set_protocol_features(features)
    }

fn get_queue_num(&self) -> Result<u64> {
    self.as_ref().get_queue_num()
}

fn set_vring_enable(&self, index: u32, enable: bool) -> Result<()> {
    self.as_ref().set_vring_enable(index, enable)
}

fn get_config(&self, offset: u32, size: u32, flags: VhostUserConfigFlags) -> Result<Vec<u8>> {
    self.as_ref().get_config(offset, size, flags)
}

fn set_config(&self, offset: u32, buf: &[u8], flags: VhostUserConfigFlags) -> Result<()> {
    self.as_ref().set_config(offset, buf, flags)
}

fn set_slave_req_fd(&self, vu_req: Box<dyn Endpoint<SlaveReq>>) {
    self.as_ref().set_slave_req_fd(vu_req)
}

fn get_inflight_fd(&self, inflight: &VhostUserInflight) -> Result<(VhostUserInflight, File)> {
    self.as_ref().get_inflight_fd(inflight)
}

fn set_inflight_fd(&self, inflight: &VhostUserInflight, file: File) -> Result<()> {
    self.as_ref().set_inflight_fd(inflight, file)
}

fn get_max_mem_slots(&self) -> Result<u64> {
    self.as_ref().get_max_mem_slots()
}

fn add_mem_region(&self, region: &VhostUserSingleMemoryRegion, fd: File) -> Result<()> {
    self.as_ref().add_mem_region(region, fd)
}

fn remove_mem_region(&self, region: &VhostUserSingleMemoryRegion) -> Result<()> {
    self.as_ref().remove_mem_region(region)
}

fn get_shared_memory_regions(&self) -> Result<Vec<VhostSharedMemoryRegion>> {
    self.as_ref().get_shared_memory_regions()
}
}

/// Abstracts |Endpoint| related operations for vhost-user slave implementations.
pub struct SlaveReqHelper<E: Endpoint<MasterReq>> {
    /// Underlying endpoint for communication.
    endpoint: E,

    /// Protocol used for the communication; decides, for example, whether vring fd requests
    /// may carry file descriptors.
    protocol: Protocol,

    /// Whether `send_ack_message` may send an ack at all. Even when set, an ack is only sent if
    /// the request itself asks for a reply. Starts out `false`.
    reply_ack_enabled: bool,
}

impl<E: Endpoint<MasterReq>> SlaveReqHelper<E> {
    /// Creates a new |SlaveReqHelper| instance with an |Endpoint| underneath it.
    ///
    /// Reply-acks start disabled.
    pub fn new(endpoint: E, protocol: Protocol) -> Self {
        SlaveReqHelper {
            endpoint,
            protocol,
            reply_ack_enabled: false,
        }
    }

    /// Builds a `REPLY` header answering `req` (same request code) whose size field covers a
    /// `T` body followed by `payload_size` extra bytes.
    ///
    /// Returns `Error::InvalidParam` if that total would exceed `MAX_MSG_SIZE`.
    fn new_reply_header<T: Sized>(
        &self,
        req: &VhostUserMsgHeader<MasterReq>,
        payload_size: usize,
    ) -> Result<VhostUserMsgHeader<MasterReq>> {
        // Bounding each term first presumably keeps the sum below from overflowing `usize`.
        if mem::size_of::<T>() > MAX_MSG_SIZE
            || payload_size > MAX_MSG_SIZE
            || mem::size_of::<T>() + payload_size > MAX_MSG_SIZE
        {
            return Err(Error::InvalidParam);
        }

        // The total is at most MAX_MSG_SIZE here; assumes that constant fits in a u32.
        Ok(VhostUserMsgHeader::new(
            req.get_code(),
            VhostUserHeaderFlag::REPLY.bits(),
            (mem::size_of::<T>() + payload_size) as u32,
        ))
    }

    /// Sends reply back to Vhost Master in response to a message.
send_ack_message( &mut self, req: &VhostUserMsgHeader<MasterReq>, success: bool, ) -> Result<()>437 pub fn send_ack_message( 438 &mut self, 439 req: &VhostUserMsgHeader<MasterReq>, 440 success: bool, 441 ) -> Result<()> { 442 if self.reply_ack_enabled && req.is_need_reply() { 443 let hdr = self.new_reply_header::<VhostUserU64>(req, 0)?; 444 let val = if success { 0 } else { 1 }; 445 let msg = VhostUserU64::new(val); 446 self.endpoint.send_message(&hdr, &msg, None)?; 447 } 448 Ok(()) 449 } 450 send_reply_message<T: Sized + DataInit>( &mut self, req: &VhostUserMsgHeader<MasterReq>, msg: &T, ) -> Result<()>451 fn send_reply_message<T: Sized + DataInit>( 452 &mut self, 453 req: &VhostUserMsgHeader<MasterReq>, 454 msg: &T, 455 ) -> Result<()> { 456 let hdr = self.new_reply_header::<T>(req, 0)?; 457 self.endpoint.send_message(&hdr, msg, None)?; 458 Ok(()) 459 } 460 send_reply_with_payload<T: Sized + DataInit>( &mut self, req: &VhostUserMsgHeader<MasterReq>, msg: &T, payload: &[u8], ) -> Result<()>461 fn send_reply_with_payload<T: Sized + DataInit>( 462 &mut self, 463 req: &VhostUserMsgHeader<MasterReq>, 464 msg: &T, 465 payload: &[u8], 466 ) -> Result<()> { 467 let hdr = self.new_reply_header::<T>(req, payload.len())?; 468 self.endpoint 469 .send_message_with_payload(&hdr, msg, payload, None)?; 470 Ok(()) 471 } 472 473 /// Parses an incoming |SET_VRING_KICK| or |SET_VRING_CALL| message into a 474 /// Vring number and an fd. 
handle_vring_fd_request( &mut self, buf: &[u8], files: Option<Vec<File>>, ) -> Result<(u8, Option<File>)>475 pub fn handle_vring_fd_request( 476 &mut self, 477 buf: &[u8], 478 files: Option<Vec<File>>, 479 ) -> Result<(u8, Option<File>)> { 480 if buf.len() > MAX_MSG_SIZE || buf.len() < mem::size_of::<VhostUserU64>() { 481 return Err(Error::InvalidMessage); 482 } 483 let msg = unsafe { std::ptr::read_unaligned(buf.as_ptr() as *const VhostUserU64) }; 484 if !msg.is_valid() { 485 return Err(Error::InvalidMessage); 486 } 487 488 // Virtio-vhost-user protocol doesn't send FDs. 489 if self.protocol == Protocol::Virtio { 490 return Ok((msg.value as u8, None)); 491 } 492 493 // Bits (0-7) of the payload contain the vring index. Bit 8 is the 494 // invalid FD flag (VHOST_USER_VRING_NOFD_MASK). 495 // This bit is set when there is no file descriptor 496 // in the ancillary data. This signals that polling will be used 497 // instead of waiting for the call. 498 // If Bit 8 is unset, the data must contain a file descriptor. 499 let has_fd = (msg.value & 0x100u64) == 0; 500 501 let file = take_single_file(files); 502 503 if has_fd && file.is_none() || !has_fd && file.is_some() { 504 return Err(Error::InvalidMessage); 505 } 506 507 Ok((msg.value as u8, file)) 508 } 509 } 510 511 impl<E: Endpoint<MasterReq>> AsRef<E> for SlaveReqHelper<E> { as_ref(&self) -> &E512 fn as_ref(&self) -> &E { 513 &self.endpoint 514 } 515 } 516 517 impl<E: Endpoint<MasterReq>> AsMut<E> for SlaveReqHelper<E> { as_mut(&mut self) -> &mut E518 fn as_mut(&mut self) -> &mut E { 519 &mut self.endpoint 520 } 521 } 522 523 impl<E: Endpoint<MasterReq> + AsRawDescriptor> AsRawDescriptor for SlaveReqHelper<E> { as_raw_descriptor(&self) -> RawDescriptor524 fn as_raw_descriptor(&self) -> RawDescriptor { 525 self.endpoint.as_raw_descriptor() 526 } 527 } 528 529 /// Server to handle service requests from masters from the master communication channel. 
///
/// The [SlaveReqHandler] acts as a server on the slave side, to handle service requests from
/// masters on the master communication channel. It's actually a proxy invoking the registered
/// handler implementing [VhostUserSlaveReqHandler] to do the real work.
///
/// The lifetime of the SlaveReqHandler object should be the same as the underlying Unix Domain
/// Socket, so it gets simpler to recover from disconnect.
///
/// [VhostUserSlaveReqHandler]: trait.VhostUserSlaveReqHandler.html
/// [SlaveReqHandler]: struct.SlaveReqHandler.html
pub struct SlaveReqHandler<S: VhostUserSlaveReqHandler, E: Endpoint<MasterReq>> {
    // Owns the endpoint plus the protocol/reply-ack state used to send replies.
    slave_req_helper: SlaveReqHelper<E>,
    // the vhost-user backend device object
    backend: S,

    // Feature bits the backend reported in its reply to GET_FEATURES (0 until then).
    virtio_features: u64,
    // Feature bits the master most recently sent in SET_FEATURES (0 until then).
    acked_virtio_features: u64,
    // Presumably the protocol-feature counterparts of the two fields above, maintained by
    // request handling not shown here; they start empty / 0 -- confirm in process_message.
    protocol_features: VhostUserProtocolFeatures,
    acked_protocol_features: u64,
}

impl<S: VhostUserSlaveReqHandler> SlaveReqHandler<S, MasterReqEndpoint> {
    /// Create a vhost-user slave endpoint from a connected socket.
    pub fn from_stream(socket: SystemStream, backend: S) -> Self {
        Self::new(MasterReqEndpoint::from(socket), backend)
    }
}

impl<S: VhostUserSlaveReqHandler, E: Endpoint<MasterReq>> AsRef<S> for SlaveReqHandler<S, E> {
    /// Borrows the backend device object.
    fn as_ref(&self) -> &S {
        &self.backend
    }
}

impl<S: VhostUserSlaveReqHandler, E: Endpoint<MasterReq>> SlaveReqHandler<S, E> {
    /// Create a vhost-user slave endpoint.
new(endpoint: E, backend: S) -> Self566 pub fn new(endpoint: E, backend: S) -> Self { 567 SlaveReqHandler { 568 slave_req_helper: SlaveReqHelper::new(endpoint, backend.protocol()), 569 backend, 570 virtio_features: 0, 571 acked_virtio_features: 0, 572 protocol_features: VhostUserProtocolFeatures::empty(), 573 acked_protocol_features: 0, 574 } 575 } 576 577 /// Receives and validates a vhost-user message header and optional files. 578 /// 579 /// Since the length of vhost-user messages are different among message types, regular 580 /// vhost-user messages are sent via an underlying communication channel in stream mode. 581 /// (e.g. `SOCK_STREAM` in UNIX) 582 /// So, the logic of receiving and handling a message consists of the following steps: 583 /// 584 /// 1. Receives a message header and optional attached file. 585 /// 2. Validates the message header. 586 /// 3. Check if optional payloads is expected. 587 /// 4. Wait for the optional payloads. 588 /// 5. Receives optional payloads. 589 /// 6. Processes the message. 590 /// 591 /// This method [`SlaveReqHandler::recv_header()`] is in charge of the step (1) and (2), 592 /// [`SlaveReqHandler::needs_wait_for_payload()`] is (3), and 593 /// [`SlaveReqHandler::process_message()`] is (5) and (6). 594 /// We need to have the three method separately for multi-platform supports; 595 /// [`SlaveReqHandler::recv_header()`] and [`SlaveReqHandler::process_message()`] need to be 596 /// separated because the way of waiting for incoming messages differs between Unix and Windows 597 /// so it's the caller's responsibility to wait before [`SlaveReqHandler::process_message()`]. 598 /// 599 /// Note that some vhost-user protocol variant such as VVU doesn't assume stream mode. In this 600 /// case, a message header and its body are sent together so the step (4) is skipped. We handle 601 /// this case in [`SlaveReqHandler::needs_wait_for_payload()`]. 
602 /// 603 /// The following pseudo code describes how a caller should process incoming vhost-user 604 /// messages: 605 /// ```ignore 606 /// loop { 607 /// // block until a message header comes. 608 /// // The actual code differs, depending on platforms. 609 /// connection.wait_readable().unwrap(); 610 /// 611 /// // (1) and (2) 612 /// let (hdr, files) = slave_req_handler.recv_header(); 613 /// 614 /// // (3) 615 /// if slave_req_handler.needs_wait_for_payload(&hdr) { 616 /// // (4) block until a payload comes if needed. 617 /// connection.wait_readable().unwrap(); 618 /// } 619 /// 620 /// // (5) and (6) 621 /// slave_req_handler.process_message(&hdr, &files).unwrap(); 622 /// } 623 /// ``` recv_header(&mut self) -> Result<(VhostUserMsgHeader<MasterReq>, Option<Vec<File>>)>624 pub fn recv_header(&mut self) -> Result<(VhostUserMsgHeader<MasterReq>, Option<Vec<File>>)> { 625 // The underlying communication channel is a Unix domain socket in 626 // stream mode, and recvmsg() is a little tricky here. To successfully 627 // receive attached file descriptors, we need to receive messages and 628 // corresponding attached file descriptors in this way: 629 // . recv messsage header and optional attached file 630 // . validate message header 631 // . recv optional message body and payload according size field in 632 // message header 633 // . validate message body and optional payload 634 let (hdr, files) = match self.slave_req_helper.endpoint.recv_header() { 635 Ok((hdr, files)) => (hdr, files), 636 Err(Error::Disconnect) => { 637 // If the client closed the connection before sending a header, this should be 638 // handled as a legal exit. 639 return Err(Error::ClientExit); 640 } 641 Err(e) => { 642 return Err(e); 643 } 644 }; 645 646 self.check_attached_files(&hdr, &files)?; 647 648 Ok((hdr, files)) 649 } 650 651 /// Returns whether the caller needs to wait for the incoming message before calling 652 /// [`SlaveReqHandler::process_message`]. 
653 /// 654 /// See [`SlaveReqHandler::recv_header`]'s doc comment for the usage. needs_wait_for_payload(&self, hdr: &VhostUserMsgHeader<MasterReq>) -> bool655 pub fn needs_wait_for_payload(&self, hdr: &VhostUserMsgHeader<MasterReq>) -> bool { 656 // For the vhost-user protocols using stream mode, we need to wait until an additional 657 // payload is available if exists. 658 self.backend.protocol().is_stream_mode() && hdr.get_size() != 0 659 } 660 661 /// Main entrance to request from the communication channel. 662 /// 663 /// Receive and handle one incoming request message from the vmm. 664 /// See [`SlaveReqHandler::recv_header`]'s doc comment for the usage. 665 /// 666 /// # Return: 667 /// * - `Ok(())`: one request was successfully handled. 668 /// * - `Err(ClientExit)`: the vmm closed the connection properly. This isn't an actual failure. 669 /// * - `Err(Disconnect)`: the connection was closed unexpectedly. 670 /// * - `Err(InvalidMessage)`: the vmm sent a illegal message. 671 /// * - other errors: failed to handle a request. 
process_message( &mut self, hdr: VhostUserMsgHeader<MasterReq>, files: Option<Vec<File>>, ) -> Result<()>672 pub fn process_message( 673 &mut self, 674 hdr: VhostUserMsgHeader<MasterReq>, 675 files: Option<Vec<File>>, 676 ) -> Result<()> { 677 let buf = match hdr.get_size() { 678 0 => vec![0u8; 0], 679 len => { 680 let rbuf = self.slave_req_helper.endpoint.recv_data(len as usize)?; 681 if rbuf.len() != len as usize { 682 return Err(Error::InvalidMessage); 683 } 684 rbuf 685 } 686 }; 687 let size = buf.len(); 688 689 match hdr.get_code() { 690 MasterReq::SET_OWNER => { 691 self.check_request_size(&hdr, size, 0)?; 692 let res = self.backend.set_owner(); 693 self.slave_req_helper.send_ack_message(&hdr, res.is_ok())?; 694 res?; 695 } 696 MasterReq::RESET_OWNER => { 697 self.check_request_size(&hdr, size, 0)?; 698 let res = self.backend.reset_owner(); 699 self.slave_req_helper.send_ack_message(&hdr, res.is_ok())?; 700 res?; 701 } 702 MasterReq::GET_FEATURES => { 703 self.check_request_size(&hdr, size, 0)?; 704 let features = self.backend.get_features()?; 705 let msg = VhostUserU64::new(features); 706 self.slave_req_helper.send_reply_message(&hdr, &msg)?; 707 self.virtio_features = features; 708 self.update_reply_ack_flag(); 709 } 710 MasterReq::SET_FEATURES => { 711 let msg = self.extract_request_body::<VhostUserU64>(&hdr, size, &buf)?; 712 let res = self.backend.set_features(msg.value); 713 self.acked_virtio_features = msg.value; 714 self.update_reply_ack_flag(); 715 self.slave_req_helper.send_ack_message(&hdr, res.is_ok())?; 716 res?; 717 } 718 MasterReq::SET_MEM_TABLE => { 719 let res = self.set_mem_table(&hdr, size, &buf, files); 720 self.slave_req_helper.send_ack_message(&hdr, res.is_ok())?; 721 res?; 722 } 723 MasterReq::SET_VRING_NUM => { 724 let msg = self.extract_request_body::<VhostUserVringState>(&hdr, size, &buf)?; 725 let res = self.backend.set_vring_num(msg.index, msg.num); 726 self.slave_req_helper.send_ack_message(&hdr, res.is_ok())?; 727 res?; 728 } 729 
MasterReq::SET_VRING_ADDR => { 730 let msg = self.extract_request_body::<VhostUserVringAddr>(&hdr, size, &buf)?; 731 let flags = match VhostUserVringAddrFlags::from_bits(msg.flags) { 732 Some(val) => val, 733 None => return Err(Error::InvalidMessage), 734 }; 735 let res = self.backend.set_vring_addr( 736 msg.index, 737 flags, 738 msg.descriptor, 739 msg.used, 740 msg.available, 741 msg.log, 742 ); 743 self.slave_req_helper.send_ack_message(&hdr, res.is_ok())?; 744 res?; 745 } 746 MasterReq::SET_VRING_BASE => { 747 let msg = self.extract_request_body::<VhostUserVringState>(&hdr, size, &buf)?; 748 let res = self.backend.set_vring_base(msg.index, msg.num); 749 self.slave_req_helper.send_ack_message(&hdr, res.is_ok())?; 750 res?; 751 } 752 MasterReq::GET_VRING_BASE => { 753 let msg = self.extract_request_body::<VhostUserVringState>(&hdr, size, &buf)?; 754 let reply = self.backend.get_vring_base(msg.index)?; 755 self.slave_req_helper.send_reply_message(&hdr, &reply)?; 756 } 757 MasterReq::SET_VRING_CALL => { 758 self.check_request_size(&hdr, size, mem::size_of::<VhostUserU64>())?; 759 let (index, file) = self.handle_vring_fd_request(&buf, files)?; 760 let res = self.backend.set_vring_call(index, file); 761 self.slave_req_helper.send_ack_message(&hdr, res.is_ok())?; 762 res?; 763 } 764 MasterReq::SET_VRING_KICK => { 765 self.check_request_size(&hdr, size, mem::size_of::<VhostUserU64>())?; 766 let (index, file) = self.handle_vring_fd_request(&buf, files)?; 767 let res = self.backend.set_vring_kick(index, file); 768 self.slave_req_helper.send_ack_message(&hdr, res.is_ok())?; 769 res?; 770 } 771 MasterReq::SET_VRING_ERR => { 772 self.check_request_size(&hdr, size, mem::size_of::<VhostUserU64>())?; 773 let (index, file) = self.handle_vring_fd_request(&buf, files)?; 774 let res = self.backend.set_vring_err(index, file); 775 self.slave_req_helper.send_ack_message(&hdr, res.is_ok())?; 776 res?; 777 } 778 MasterReq::GET_PROTOCOL_FEATURES => { 779 self.check_request_size(&hdr, 
size, 0)?; 780 let features = self.backend.get_protocol_features()?; 781 let msg = VhostUserU64::new(features.bits()); 782 self.slave_req_helper.send_reply_message(&hdr, &msg)?; 783 self.protocol_features = features; 784 self.update_reply_ack_flag(); 785 } 786 MasterReq::SET_PROTOCOL_FEATURES => { 787 let msg = self.extract_request_body::<VhostUserU64>(&hdr, size, &buf)?; 788 let res = self.backend.set_protocol_features(msg.value); 789 self.acked_protocol_features = msg.value; 790 self.update_reply_ack_flag(); 791 self.slave_req_helper.send_ack_message(&hdr, res.is_ok())?; 792 res?; 793 } 794 MasterReq::GET_QUEUE_NUM => { 795 if self.acked_protocol_features & VhostUserProtocolFeatures::MQ.bits() == 0 { 796 return Err(Error::InvalidOperation); 797 } 798 self.check_request_size(&hdr, size, 0)?; 799 let num = self.backend.get_queue_num()?; 800 let msg = VhostUserU64::new(num); 801 self.slave_req_helper.send_reply_message(&hdr, &msg)?; 802 } 803 MasterReq::SET_VRING_ENABLE => { 804 let msg = self.extract_request_body::<VhostUserVringState>(&hdr, size, &buf)?; 805 if self.acked_virtio_features & VhostUserVirtioFeatures::PROTOCOL_FEATURES.bits() 806 == 0 807 { 808 return Err(Error::InvalidOperation); 809 } 810 let enable = match msg.num { 811 1 => true, 812 0 => false, 813 _ => return Err(Error::InvalidParam), 814 }; 815 816 let res = self.backend.set_vring_enable(msg.index, enable); 817 self.slave_req_helper.send_ack_message(&hdr, res.is_ok())?; 818 res?; 819 } 820 MasterReq::GET_CONFIG => { 821 if self.acked_protocol_features & VhostUserProtocolFeatures::CONFIG.bits() == 0 { 822 return Err(Error::InvalidOperation); 823 } 824 self.check_request_size(&hdr, size, hdr.get_size() as usize)?; 825 self.get_config(&hdr, &buf)?; 826 } 827 MasterReq::SET_CONFIG => { 828 if self.acked_protocol_features & VhostUserProtocolFeatures::CONFIG.bits() == 0 { 829 return Err(Error::InvalidOperation); 830 } 831 self.check_request_size(&hdr, size, hdr.get_size() as usize)?; 832 let res = 
self.set_config(size, &buf); 833 self.slave_req_helper.send_ack_message(&hdr, res.is_ok())?; 834 res?; 835 } 836 MasterReq::SET_SLAVE_REQ_FD => { 837 if self.acked_protocol_features & VhostUserProtocolFeatures::SLAVE_REQ.bits() == 0 { 838 return Err(Error::InvalidOperation); 839 } 840 self.check_request_size(&hdr, size, hdr.get_size() as usize)?; 841 let res = self.set_slave_req_fd(files); 842 self.slave_req_helper.send_ack_message(&hdr, res.is_ok())?; 843 res?; 844 } 845 MasterReq::GET_INFLIGHT_FD => { 846 if self.acked_protocol_features & VhostUserProtocolFeatures::INFLIGHT_SHMFD.bits() 847 == 0 848 { 849 return Err(Error::InvalidOperation); 850 } 851 852 let msg = self.extract_request_body::<VhostUserInflight>(&hdr, size, &buf)?; 853 let (inflight, file) = self.backend.get_inflight_fd(&msg)?; 854 let reply_hdr = self 855 .slave_req_helper 856 .new_reply_header::<VhostUserInflight>(&hdr, 0)?; 857 self.slave_req_helper.endpoint.send_message( 858 &reply_hdr, 859 &inflight, 860 Some(&[file.as_raw_descriptor()]), 861 )?; 862 } 863 MasterReq::SET_INFLIGHT_FD => { 864 if self.acked_protocol_features & VhostUserProtocolFeatures::INFLIGHT_SHMFD.bits() 865 == 0 866 { 867 return Err(Error::InvalidOperation); 868 } 869 let file = take_single_file(files).ok_or(Error::IncorrectFds)?; 870 let msg = self.extract_request_body::<VhostUserInflight>(&hdr, size, &buf)?; 871 let res = self.backend.set_inflight_fd(&msg, file); 872 self.slave_req_helper.send_ack_message(&hdr, res.is_ok())?; 873 res?; 874 } 875 MasterReq::GET_MAX_MEM_SLOTS => { 876 if self.acked_protocol_features 877 & VhostUserProtocolFeatures::CONFIGURE_MEM_SLOTS.bits() 878 == 0 879 { 880 return Err(Error::InvalidOperation); 881 } 882 self.check_request_size(&hdr, size, 0)?; 883 let num = self.backend.get_max_mem_slots()?; 884 let msg = VhostUserU64::new(num); 885 self.slave_req_helper.send_reply_message(&hdr, &msg)?; 886 } 887 MasterReq::ADD_MEM_REG => { 888 if self.acked_protocol_features 889 & 
VhostUserProtocolFeatures::CONFIGURE_MEM_SLOTS.bits() 890 == 0 891 { 892 return Err(Error::InvalidOperation); 893 } 894 let mut files = files.ok_or(Error::InvalidParam)?; 895 if files.len() != 1 { 896 return Err(Error::InvalidParam); 897 } 898 let msg = 899 self.extract_request_body::<VhostUserSingleMemoryRegion>(&hdr, size, &buf)?; 900 let res = self.backend.add_mem_region(&msg, files.swap_remove(0)); 901 self.slave_req_helper.send_ack_message(&hdr, res.is_ok())?; 902 res?; 903 } 904 MasterReq::REM_MEM_REG => { 905 if self.acked_protocol_features 906 & VhostUserProtocolFeatures::CONFIGURE_MEM_SLOTS.bits() 907 == 0 908 { 909 return Err(Error::InvalidOperation); 910 } 911 912 let msg = 913 self.extract_request_body::<VhostUserSingleMemoryRegion>(&hdr, size, &buf)?; 914 let res = self.backend.remove_mem_region(&msg); 915 self.slave_req_helper.send_ack_message(&hdr, res.is_ok())?; 916 res?; 917 } 918 MasterReq::GET_SHARED_MEMORY_REGIONS => { 919 let regions = self.backend.get_shared_memory_regions()?; 920 let mut buf = Vec::new(); 921 let msg = VhostUserU64::new(regions.len() as u64); 922 for r in regions { 923 buf.extend_from_slice(r.as_bytes()) 924 } 925 self.slave_req_helper 926 .send_reply_with_payload(&hdr, &msg, buf.as_slice())?; 927 } 928 _ => { 929 return Err(Error::InvalidMessage); 930 } 931 } 932 Ok(()) 933 } 934 set_mem_table( &mut self, hdr: &VhostUserMsgHeader<MasterReq>, size: usize, buf: &[u8], files: Option<Vec<File>>, ) -> Result<()>935 fn set_mem_table( 936 &mut self, 937 hdr: &VhostUserMsgHeader<MasterReq>, 938 size: usize, 939 buf: &[u8], 940 files: Option<Vec<File>>, 941 ) -> Result<()> { 942 self.check_request_size(hdr, size, hdr.get_size() as usize)?; 943 944 // check message size is consistent 945 let hdrsize = mem::size_of::<VhostUserMemory>(); 946 if size < hdrsize { 947 return Err(Error::InvalidMessage); 948 } 949 let msg = unsafe { &*(buf.as_ptr() as *const VhostUserMemory) }; 950 if !msg.is_valid() { 951 return Err(Error::InvalidMessage); 
952 } 953 if size != hdrsize + msg.num_regions as usize * mem::size_of::<VhostUserMemoryRegion>() { 954 return Err(Error::InvalidMessage); 955 } 956 957 let files = match self.slave_req_helper.protocol { 958 Protocol::Regular => { 959 // validate number of fds matching number of memory regions 960 let files = files.ok_or(Error::InvalidMessage)?; 961 if files.len() != msg.num_regions as usize { 962 return Err(Error::InvalidMessage); 963 } 964 files 965 } 966 Protocol::Virtio => vec![], 967 }; 968 969 // Validate memory regions 970 let regions = unsafe { 971 slice::from_raw_parts( 972 buf.as_ptr().add(hdrsize) as *const VhostUserMemoryRegion, 973 msg.num_regions as usize, 974 ) 975 }; 976 for region in regions.iter() { 977 if !region.is_valid() { 978 return Err(Error::InvalidMessage); 979 } 980 } 981 982 self.backend.set_mem_table(regions, files) 983 } 984 get_config(&mut self, hdr: &VhostUserMsgHeader<MasterReq>, buf: &[u8]) -> Result<()>985 fn get_config(&mut self, hdr: &VhostUserMsgHeader<MasterReq>, buf: &[u8]) -> Result<()> { 986 let payload_offset = mem::size_of::<VhostUserConfig>(); 987 if buf.len() > MAX_MSG_SIZE || buf.len() < payload_offset { 988 return Err(Error::InvalidMessage); 989 } 990 let msg = unsafe { std::ptr::read_unaligned(buf.as_ptr() as *const VhostUserConfig) }; 991 if !msg.is_valid() { 992 return Err(Error::InvalidMessage); 993 } 994 if buf.len() - payload_offset != msg.size as usize { 995 return Err(Error::InvalidMessage); 996 } 997 let flags = match VhostUserConfigFlags::from_bits(msg.flags) { 998 Some(val) => val, 999 None => return Err(Error::InvalidMessage), 1000 }; 1001 let res = self.backend.get_config(msg.offset, msg.size, flags); 1002 1003 // vhost-user slave's payload size MUST match master's request 1004 // on success, uses zero length of payload to indicate an error 1005 // to vhost-user master. 
1006 match res { 1007 Ok(ref buf) if buf.len() == msg.size as usize => { 1008 let reply = VhostUserConfig::new(msg.offset, buf.len() as u32, flags); 1009 self.slave_req_helper 1010 .send_reply_with_payload(hdr, &reply, buf.as_slice())?; 1011 } 1012 Ok(_) => { 1013 let reply = VhostUserConfig::new(msg.offset, 0, flags); 1014 self.slave_req_helper.send_reply_message(hdr, &reply)?; 1015 } 1016 Err(_) => { 1017 let reply = VhostUserConfig::new(msg.offset, 0, flags); 1018 self.slave_req_helper.send_reply_message(hdr, &reply)?; 1019 } 1020 } 1021 Ok(()) 1022 } 1023 set_config(&mut self, size: usize, buf: &[u8]) -> Result<()>1024 fn set_config(&mut self, size: usize, buf: &[u8]) -> Result<()> { 1025 if size > MAX_MSG_SIZE || size < mem::size_of::<VhostUserConfig>() { 1026 return Err(Error::InvalidMessage); 1027 } 1028 let msg = unsafe { std::ptr::read_unaligned(buf.as_ptr() as *const VhostUserConfig) }; 1029 if !msg.is_valid() { 1030 return Err(Error::InvalidMessage); 1031 } 1032 if size - mem::size_of::<VhostUserConfig>() != msg.size as usize { 1033 return Err(Error::InvalidMessage); 1034 } 1035 let flags: VhostUserConfigFlags = match VhostUserConfigFlags::from_bits(msg.flags) { 1036 Some(val) => val, 1037 None => return Err(Error::InvalidMessage), 1038 }; 1039 1040 self.backend.set_config(msg.offset, buf, flags) 1041 } 1042 set_slave_req_fd(&mut self, files: Option<Vec<File>>) -> Result<()>1043 fn set_slave_req_fd(&mut self, files: Option<Vec<File>>) -> Result<()> { 1044 let ep = self 1045 .slave_req_helper 1046 .endpoint 1047 .create_slave_request_endpoint(files)?; 1048 self.backend.set_slave_req_fd(ep); 1049 Ok(()) 1050 } 1051 handle_vring_fd_request( &mut self, buf: &[u8], files: Option<Vec<File>>, ) -> Result<(u8, Option<File>)>1052 fn handle_vring_fd_request( 1053 &mut self, 1054 buf: &[u8], 1055 files: Option<Vec<File>>, 1056 ) -> Result<(u8, Option<File>)> { 1057 self.slave_req_helper.handle_vring_fd_request(buf, files) 1058 } 1059 check_request_size( &self, hdr: 
&VhostUserMsgHeader<MasterReq>, size: usize, expected: usize, ) -> Result<()>1060 fn check_request_size( 1061 &self, 1062 hdr: &VhostUserMsgHeader<MasterReq>, 1063 size: usize, 1064 expected: usize, 1065 ) -> Result<()> { 1066 if hdr.get_size() as usize != expected 1067 || hdr.is_reply() 1068 || hdr.get_version() != 0x1 1069 || size != expected 1070 { 1071 return Err(Error::InvalidMessage); 1072 } 1073 Ok(()) 1074 } 1075 check_attached_files( &self, hdr: &VhostUserMsgHeader<MasterReq>, files: &Option<Vec<File>>, ) -> Result<()>1076 fn check_attached_files( 1077 &self, 1078 hdr: &VhostUserMsgHeader<MasterReq>, 1079 files: &Option<Vec<File>>, 1080 ) -> Result<()> { 1081 match hdr.get_code() { 1082 MasterReq::SET_MEM_TABLE 1083 | MasterReq::SET_VRING_CALL 1084 | MasterReq::SET_VRING_KICK 1085 | MasterReq::SET_VRING_ERR 1086 | MasterReq::SET_LOG_BASE 1087 | MasterReq::SET_LOG_FD 1088 | MasterReq::SET_SLAVE_REQ_FD 1089 | MasterReq::SET_INFLIGHT_FD 1090 | MasterReq::ADD_MEM_REG => Ok(()), 1091 _ if files.is_some() => Err(Error::InvalidMessage), 1092 _ => Ok(()), 1093 } 1094 } 1095 extract_request_body<T: Sized + DataInit + VhostUserMsgValidator>( &self, hdr: &VhostUserMsgHeader<MasterReq>, size: usize, buf: &[u8], ) -> Result<T>1096 fn extract_request_body<T: Sized + DataInit + VhostUserMsgValidator>( 1097 &self, 1098 hdr: &VhostUserMsgHeader<MasterReq>, 1099 size: usize, 1100 buf: &[u8], 1101 ) -> Result<T> { 1102 self.check_request_size(hdr, size, mem::size_of::<T>())?; 1103 let msg = unsafe { std::ptr::read_unaligned(buf.as_ptr() as *const T) }; 1104 if !msg.is_valid() { 1105 return Err(Error::InvalidMessage); 1106 } 1107 Ok(msg) 1108 } 1109 update_reply_ack_flag(&mut self)1110 fn update_reply_ack_flag(&mut self) { 1111 let vflag = VhostUserVirtioFeatures::PROTOCOL_FEATURES.bits(); 1112 let pflag = VhostUserProtocolFeatures::REPLY_ACK; 1113 if (self.virtio_features & vflag) != 0 1114 && self.protocol_features.contains(pflag) 1115 && (self.acked_protocol_features & 
pflag.bits()) != 0 1116 { 1117 self.slave_req_helper.reply_ack_enabled = true; 1118 } else { 1119 self.slave_req_helper.reply_ack_enabled = false; 1120 } 1121 } 1122 } 1123 1124 impl<S: VhostUserSlaveReqHandler, E: AsRawDescriptor + Endpoint<MasterReq>> AsRawDescriptor 1125 for SlaveReqHandler<S, E> 1126 { as_raw_descriptor(&self) -> RawDescriptor1127 fn as_raw_descriptor(&self) -> RawDescriptor { 1128 // TODO(b/221882601): figure out if this used for polling. 1129 self.slave_req_helper.endpoint.as_raw_descriptor() 1130 } 1131 } 1132 1133 #[cfg(test)] 1134 mod tests { 1135 use base::INVALID_DESCRIPTOR; 1136 1137 use super::*; 1138 use crate::dummy_slave::DummySlaveReqHandler; 1139 use crate::MasterReqEndpoint; 1140 use crate::SystemStream; 1141 1142 #[test] test_slave_req_handler_new()1143 fn test_slave_req_handler_new() { 1144 let (p1, _p2) = SystemStream::pair().unwrap(); 1145 let endpoint = MasterReqEndpoint::from(p1); 1146 let backend = Mutex::new(DummySlaveReqHandler::new()); 1147 let handler = SlaveReqHandler::new(endpoint, backend); 1148 1149 assert!(handler.as_raw_descriptor() != INVALID_DESCRIPTOR); 1150 } 1151 } 1152