• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 // Copyright (C) 2019 Alibaba Cloud Computing. All rights reserved.
2 // SPDX-License-Identifier: Apache-2.0
3 
4 use std::fs::File;
5 use std::mem;
6 use std::slice;
7 use std::sync::Mutex;
8 
9 use base::AsRawDescriptor;
10 use base::RawDescriptor;
11 use data_model::DataInit;
12 use zerocopy::AsBytes;
13 
14 use crate::connection::Endpoint;
15 use crate::connection::EndpointExt;
16 use crate::message::*;
17 use crate::take_single_file;
18 use crate::Error;
19 use crate::MasterReqEndpoint;
20 use crate::Result;
21 use crate::SystemStream;
22 
/// Vhost-user protocol variants used for the communication.
#[derive(Clone, Copy, PartialEq, Eq, Debug)]
pub enum Protocol {
    /// Use the regular vhost-user protocol.
    Regular,
    /// Use the virtio-vhost-user protocol, which is proxied through virtqueues.
    /// The protocol is mostly the same as the vhost-user protocol, but no file transfer is
    /// allowed.
    Virtio,
}

impl Protocol {
    /// Returns whether the protocol assumes that messages are sent in stream mode, like Unix's
    /// `SOCK_STREAM`.
    ///
    /// In stream mode, receivers cannot know the size of the entire message in advance, so a
    /// message header carrying the body size and the message body are sent separately. See
    /// [`SlaveReqHandler::recv_header()`]'s doc comment for more details.
    fn is_stream_mode(&self) -> bool {
        // Deliberately exhaustive (no `_` arm): adding a variant must force a decision here.
        match self {
            Protocol::Regular => true,
            // VVU proxy sends a message header and its payload at once.
            Protocol::Virtio => false,
        }
    }
}
47 
48 /// Services provided to the master by the slave with interior mutability.
49 ///
50 /// The [VhostUserSlaveReqHandler] trait defines the services provided to the master by the slave.
51 /// And the [VhostUserSlaveReqHandlerMut] trait is a helper mirroring [VhostUserSlaveReqHandler],
52 /// but without interior mutability.
53 /// The vhost-user specification defines a master communication channel, by which masters could
54 /// request services from slaves. The [VhostUserSlaveReqHandler] trait defines services provided by
55 /// slaves, and it's used both on the master side and slave side.
56 ///
57 /// - on the master side, a stub forwarder implementing [VhostUserSlaveReqHandler] will proxy
58 ///   service requests to slaves.
59 /// - on the slave side, the [SlaveReqHandler] will forward service requests to a handler
60 ///   implementing [VhostUserSlaveReqHandler].
61 ///
62 /// The [VhostUserSlaveReqHandler] trait is design with interior mutability to improve performance
63 /// for multi-threading.
64 ///
65 /// [VhostUserSlaveReqHandler]: trait.VhostUserSlaveReqHandler.html
66 /// [VhostUserSlaveReqHandlerMut]: trait.VhostUserSlaveReqHandlerMut.html
67 /// [SlaveReqHandler]: struct.SlaveReqHandler.html
68 #[allow(missing_docs)]
69 pub trait VhostUserSlaveReqHandler {
70     /// Returns the type of vhost-user protocol that the handler support.
protocol(&self) -> Protocol71     fn protocol(&self) -> Protocol;
72 
set_owner(&self) -> Result<()>73     fn set_owner(&self) -> Result<()>;
reset_owner(&self) -> Result<()>74     fn reset_owner(&self) -> Result<()>;
get_features(&self) -> Result<u64>75     fn get_features(&self) -> Result<u64>;
set_features(&self, features: u64) -> Result<()>76     fn set_features(&self, features: u64) -> Result<()>;
set_mem_table(&self, ctx: &[VhostUserMemoryRegion], files: Vec<File>) -> Result<()>77     fn set_mem_table(&self, ctx: &[VhostUserMemoryRegion], files: Vec<File>) -> Result<()>;
set_vring_num(&self, index: u32, num: u32) -> Result<()>78     fn set_vring_num(&self, index: u32, num: u32) -> Result<()>;
set_vring_addr( &self, index: u32, flags: VhostUserVringAddrFlags, descriptor: u64, used: u64, available: u64, log: u64, ) -> Result<()>79     fn set_vring_addr(
80         &self,
81         index: u32,
82         flags: VhostUserVringAddrFlags,
83         descriptor: u64,
84         used: u64,
85         available: u64,
86         log: u64,
87     ) -> Result<()>;
set_vring_base(&self, index: u32, base: u32) -> Result<()>88     fn set_vring_base(&self, index: u32, base: u32) -> Result<()>;
get_vring_base(&self, index: u32) -> Result<VhostUserVringState>89     fn get_vring_base(&self, index: u32) -> Result<VhostUserVringState>;
set_vring_kick(&self, index: u8, fd: Option<File>) -> Result<()>90     fn set_vring_kick(&self, index: u8, fd: Option<File>) -> Result<()>;
set_vring_call(&self, index: u8, fd: Option<File>) -> Result<()>91     fn set_vring_call(&self, index: u8, fd: Option<File>) -> Result<()>;
set_vring_err(&self, index: u8, fd: Option<File>) -> Result<()>92     fn set_vring_err(&self, index: u8, fd: Option<File>) -> Result<()>;
93 
get_protocol_features(&self) -> Result<VhostUserProtocolFeatures>94     fn get_protocol_features(&self) -> Result<VhostUserProtocolFeatures>;
set_protocol_features(&self, features: u64) -> Result<()>95     fn set_protocol_features(&self, features: u64) -> Result<()>;
get_queue_num(&self) -> Result<u64>96     fn get_queue_num(&self) -> Result<u64>;
set_vring_enable(&self, index: u32, enable: bool) -> Result<()>97     fn set_vring_enable(&self, index: u32, enable: bool) -> Result<()>;
get_config(&self, offset: u32, size: u32, flags: VhostUserConfigFlags) -> Result<Vec<u8>>98     fn get_config(&self, offset: u32, size: u32, flags: VhostUserConfigFlags) -> Result<Vec<u8>>;
set_config(&self, offset: u32, buf: &[u8], flags: VhostUserConfigFlags) -> Result<()>99     fn set_config(&self, offset: u32, buf: &[u8], flags: VhostUserConfigFlags) -> Result<()>;
set_slave_req_fd(&self, _vu_req: Box<dyn Endpoint<SlaveReq>>)100     fn set_slave_req_fd(&self, _vu_req: Box<dyn Endpoint<SlaveReq>>) {}
get_inflight_fd(&self, inflight: &VhostUserInflight) -> Result<(VhostUserInflight, File)>101     fn get_inflight_fd(&self, inflight: &VhostUserInflight) -> Result<(VhostUserInflight, File)>;
set_inflight_fd(&self, inflight: &VhostUserInflight, file: File) -> Result<()>102     fn set_inflight_fd(&self, inflight: &VhostUserInflight, file: File) -> Result<()>;
get_max_mem_slots(&self) -> Result<u64>103     fn get_max_mem_slots(&self) -> Result<u64>;
add_mem_region(&self, region: &VhostUserSingleMemoryRegion, fd: File) -> Result<()>104     fn add_mem_region(&self, region: &VhostUserSingleMemoryRegion, fd: File) -> Result<()>;
remove_mem_region(&self, region: &VhostUserSingleMemoryRegion) -> Result<()>105     fn remove_mem_region(&self, region: &VhostUserSingleMemoryRegion) -> Result<()>;
get_shared_memory_regions(&self) -> Result<Vec<VhostSharedMemoryRegion>>106     fn get_shared_memory_regions(&self) -> Result<Vec<VhostSharedMemoryRegion>>;
107 }
108 
109 /// Services provided to the master by the slave without interior mutability.
110 ///
111 /// This is a helper trait mirroring the [VhostUserSlaveReqHandler] trait.
112 #[allow(missing_docs)]
113 pub trait VhostUserSlaveReqHandlerMut {
114     /// Returns the type of vhost-user protocol that the handler support.
protocol(&self) -> Protocol115     fn protocol(&self) -> Protocol;
116 
set_owner(&mut self) -> Result<()>117     fn set_owner(&mut self) -> Result<()>;
reset_owner(&mut self) -> Result<()>118     fn reset_owner(&mut self) -> Result<()>;
get_features(&mut self) -> Result<u64>119     fn get_features(&mut self) -> Result<u64>;
set_features(&mut self, features: u64) -> Result<()>120     fn set_features(&mut self, features: u64) -> Result<()>;
set_mem_table(&mut self, ctx: &[VhostUserMemoryRegion], files: Vec<File>) -> Result<()>121     fn set_mem_table(&mut self, ctx: &[VhostUserMemoryRegion], files: Vec<File>) -> Result<()>;
set_vring_num(&mut self, index: u32, num: u32) -> Result<()>122     fn set_vring_num(&mut self, index: u32, num: u32) -> Result<()>;
set_vring_addr( &mut self, index: u32, flags: VhostUserVringAddrFlags, descriptor: u64, used: u64, available: u64, log: u64, ) -> Result<()>123     fn set_vring_addr(
124         &mut self,
125         index: u32,
126         flags: VhostUserVringAddrFlags,
127         descriptor: u64,
128         used: u64,
129         available: u64,
130         log: u64,
131     ) -> Result<()>;
set_vring_base(&mut self, index: u32, base: u32) -> Result<()>132     fn set_vring_base(&mut self, index: u32, base: u32) -> Result<()>;
get_vring_base(&mut self, index: u32) -> Result<VhostUserVringState>133     fn get_vring_base(&mut self, index: u32) -> Result<VhostUserVringState>;
set_vring_kick(&mut self, index: u8, fd: Option<File>) -> Result<()>134     fn set_vring_kick(&mut self, index: u8, fd: Option<File>) -> Result<()>;
set_vring_call(&mut self, index: u8, fd: Option<File>) -> Result<()>135     fn set_vring_call(&mut self, index: u8, fd: Option<File>) -> Result<()>;
set_vring_err(&mut self, index: u8, fd: Option<File>) -> Result<()>136     fn set_vring_err(&mut self, index: u8, fd: Option<File>) -> Result<()>;
137 
get_protocol_features(&mut self) -> Result<VhostUserProtocolFeatures>138     fn get_protocol_features(&mut self) -> Result<VhostUserProtocolFeatures>;
set_protocol_features(&mut self, features: u64) -> Result<()>139     fn set_protocol_features(&mut self, features: u64) -> Result<()>;
get_queue_num(&mut self) -> Result<u64>140     fn get_queue_num(&mut self) -> Result<u64>;
set_vring_enable(&mut self, index: u32, enable: bool) -> Result<()>141     fn set_vring_enable(&mut self, index: u32, enable: bool) -> Result<()>;
get_config( &mut self, offset: u32, size: u32, flags: VhostUserConfigFlags, ) -> Result<Vec<u8>>142     fn get_config(
143         &mut self,
144         offset: u32,
145         size: u32,
146         flags: VhostUserConfigFlags,
147     ) -> Result<Vec<u8>>;
set_config(&mut self, offset: u32, buf: &[u8], flags: VhostUserConfigFlags) -> Result<()>148     fn set_config(&mut self, offset: u32, buf: &[u8], flags: VhostUserConfigFlags) -> Result<()>;
set_slave_req_fd(&mut self, _vu_req: Box<dyn Endpoint<SlaveReq>>)149     fn set_slave_req_fd(&mut self, _vu_req: Box<dyn Endpoint<SlaveReq>>) {}
get_inflight_fd( &mut self, inflight: &VhostUserInflight, ) -> Result<(VhostUserInflight, File)>150     fn get_inflight_fd(
151         &mut self,
152         inflight: &VhostUserInflight,
153     ) -> Result<(VhostUserInflight, File)>;
set_inflight_fd(&mut self, inflight: &VhostUserInflight, file: File) -> Result<()>154     fn set_inflight_fd(&mut self, inflight: &VhostUserInflight, file: File) -> Result<()>;
get_max_mem_slots(&mut self) -> Result<u64>155     fn get_max_mem_slots(&mut self) -> Result<u64>;
add_mem_region(&mut self, region: &VhostUserSingleMemoryRegion, fd: File) -> Result<()>156     fn add_mem_region(&mut self, region: &VhostUserSingleMemoryRegion, fd: File) -> Result<()>;
remove_mem_region(&mut self, region: &VhostUserSingleMemoryRegion) -> Result<()>157     fn remove_mem_region(&mut self, region: &VhostUserSingleMemoryRegion) -> Result<()>;
get_shared_memory_regions(&mut self) -> Result<Vec<VhostSharedMemoryRegion>>158     fn get_shared_memory_regions(&mut self) -> Result<Vec<VhostSharedMemoryRegion>>;
159 }
160 
161 impl<T: VhostUserSlaveReqHandlerMut> VhostUserSlaveReqHandler for Mutex<T> {
protocol(&self) -> Protocol162     fn protocol(&self) -> Protocol {
163         self.lock().unwrap().protocol()
164     }
165 
set_owner(&self) -> Result<()>166     fn set_owner(&self) -> Result<()> {
167         self.lock().unwrap().set_owner()
168     }
169 
reset_owner(&self) -> Result<()>170     fn reset_owner(&self) -> Result<()> {
171         self.lock().unwrap().reset_owner()
172     }
173 
get_features(&self) -> Result<u64>174     fn get_features(&self) -> Result<u64> {
175         self.lock().unwrap().get_features()
176     }
177 
set_features(&self, features: u64) -> Result<()>178     fn set_features(&self, features: u64) -> Result<()> {
179         self.lock().unwrap().set_features(features)
180     }
181 
set_mem_table(&self, ctx: &[VhostUserMemoryRegion], files: Vec<File>) -> Result<()>182     fn set_mem_table(&self, ctx: &[VhostUserMemoryRegion], files: Vec<File>) -> Result<()> {
183         self.lock().unwrap().set_mem_table(ctx, files)
184     }
185 
set_vring_num(&self, index: u32, num: u32) -> Result<()>186     fn set_vring_num(&self, index: u32, num: u32) -> Result<()> {
187         self.lock().unwrap().set_vring_num(index, num)
188     }
189 
set_vring_addr( &self, index: u32, flags: VhostUserVringAddrFlags, descriptor: u64, used: u64, available: u64, log: u64, ) -> Result<()>190     fn set_vring_addr(
191         &self,
192         index: u32,
193         flags: VhostUserVringAddrFlags,
194         descriptor: u64,
195         used: u64,
196         available: u64,
197         log: u64,
198     ) -> Result<()> {
199         self.lock()
200             .unwrap()
201             .set_vring_addr(index, flags, descriptor, used, available, log)
202     }
203 
set_vring_base(&self, index: u32, base: u32) -> Result<()>204     fn set_vring_base(&self, index: u32, base: u32) -> Result<()> {
205         self.lock().unwrap().set_vring_base(index, base)
206     }
207 
get_vring_base(&self, index: u32) -> Result<VhostUserVringState>208     fn get_vring_base(&self, index: u32) -> Result<VhostUserVringState> {
209         self.lock().unwrap().get_vring_base(index)
210     }
211 
set_vring_kick(&self, index: u8, fd: Option<File>) -> Result<()>212     fn set_vring_kick(&self, index: u8, fd: Option<File>) -> Result<()> {
213         self.lock().unwrap().set_vring_kick(index, fd)
214     }
215 
set_vring_call(&self, index: u8, fd: Option<File>) -> Result<()>216     fn set_vring_call(&self, index: u8, fd: Option<File>) -> Result<()> {
217         self.lock().unwrap().set_vring_call(index, fd)
218     }
219 
set_vring_err(&self, index: u8, fd: Option<File>) -> Result<()>220     fn set_vring_err(&self, index: u8, fd: Option<File>) -> Result<()> {
221         self.lock().unwrap().set_vring_err(index, fd)
222     }
223 
get_protocol_features(&self) -> Result<VhostUserProtocolFeatures>224     fn get_protocol_features(&self) -> Result<VhostUserProtocolFeatures> {
225         self.lock().unwrap().get_protocol_features()
226     }
227 
set_protocol_features(&self, features: u64) -> Result<()>228     fn set_protocol_features(&self, features: u64) -> Result<()> {
229         self.lock().unwrap().set_protocol_features(features)
230     }
231 
get_queue_num(&self) -> Result<u64>232     fn get_queue_num(&self) -> Result<u64> {
233         self.lock().unwrap().get_queue_num()
234     }
235 
set_vring_enable(&self, index: u32, enable: bool) -> Result<()>236     fn set_vring_enable(&self, index: u32, enable: bool) -> Result<()> {
237         self.lock().unwrap().set_vring_enable(index, enable)
238     }
239 
get_config(&self, offset: u32, size: u32, flags: VhostUserConfigFlags) -> Result<Vec<u8>>240     fn get_config(&self, offset: u32, size: u32, flags: VhostUserConfigFlags) -> Result<Vec<u8>> {
241         self.lock().unwrap().get_config(offset, size, flags)
242     }
243 
set_config(&self, offset: u32, buf: &[u8], flags: VhostUserConfigFlags) -> Result<()>244     fn set_config(&self, offset: u32, buf: &[u8], flags: VhostUserConfigFlags) -> Result<()> {
245         self.lock().unwrap().set_config(offset, buf, flags)
246     }
247 
set_slave_req_fd(&self, vu_req: Box<dyn Endpoint<SlaveReq>>)248     fn set_slave_req_fd(&self, vu_req: Box<dyn Endpoint<SlaveReq>>) {
249         self.lock().unwrap().set_slave_req_fd(vu_req)
250     }
251 
get_inflight_fd(&self, inflight: &VhostUserInflight) -> Result<(VhostUserInflight, File)>252     fn get_inflight_fd(&self, inflight: &VhostUserInflight) -> Result<(VhostUserInflight, File)> {
253         self.lock().unwrap().get_inflight_fd(inflight)
254     }
255 
set_inflight_fd(&self, inflight: &VhostUserInflight, file: File) -> Result<()>256     fn set_inflight_fd(&self, inflight: &VhostUserInflight, file: File) -> Result<()> {
257         self.lock().unwrap().set_inflight_fd(inflight, file)
258     }
259 
get_max_mem_slots(&self) -> Result<u64>260     fn get_max_mem_slots(&self) -> Result<u64> {
261         self.lock().unwrap().get_max_mem_slots()
262     }
263 
add_mem_region(&self, region: &VhostUserSingleMemoryRegion, fd: File) -> Result<()>264     fn add_mem_region(&self, region: &VhostUserSingleMemoryRegion, fd: File) -> Result<()> {
265         self.lock().unwrap().add_mem_region(region, fd)
266     }
267 
remove_mem_region(&self, region: &VhostUserSingleMemoryRegion) -> Result<()>268     fn remove_mem_region(&self, region: &VhostUserSingleMemoryRegion) -> Result<()> {
269         self.lock().unwrap().remove_mem_region(region)
270     }
271 
get_shared_memory_regions(&self) -> Result<Vec<VhostSharedMemoryRegion>>272     fn get_shared_memory_regions(&self) -> Result<Vec<VhostSharedMemoryRegion>> {
273         self.lock().unwrap().get_shared_memory_regions()
274     }
275 }
276 
277 impl<T> VhostUserSlaveReqHandler for T
278 where
279     T: AsRef<dyn VhostUserSlaveReqHandler>,
280 {
protocol(&self) -> Protocol281     fn protocol(&self) -> Protocol {
282         self.as_ref().protocol()
283     }
284 
set_owner(&self) -> Result<()>285     fn set_owner(&self) -> Result<()> {
286         self.as_ref().set_owner()
287     }
288 
reset_owner(&self) -> Result<()>289     fn reset_owner(&self) -> Result<()> {
290         self.as_ref().reset_owner()
291     }
292 
get_features(&self) -> Result<u64>293     fn get_features(&self) -> Result<u64> {
294         self.as_ref().get_features()
295     }
296 
set_features(&self, features: u64) -> Result<()>297     fn set_features(&self, features: u64) -> Result<()> {
298         self.as_ref().set_features(features)
299     }
300 
set_mem_table(&self, ctx: &[VhostUserMemoryRegion], files: Vec<File>) -> Result<()>301     fn set_mem_table(&self, ctx: &[VhostUserMemoryRegion], files: Vec<File>) -> Result<()> {
302         self.as_ref().set_mem_table(ctx, files)
303     }
304 
set_vring_num(&self, index: u32, num: u32) -> Result<()>305     fn set_vring_num(&self, index: u32, num: u32) -> Result<()> {
306         self.as_ref().set_vring_num(index, num)
307     }
308 
set_vring_addr( &self, index: u32, flags: VhostUserVringAddrFlags, descriptor: u64, used: u64, available: u64, log: u64, ) -> Result<()>309     fn set_vring_addr(
310         &self,
311         index: u32,
312         flags: VhostUserVringAddrFlags,
313         descriptor: u64,
314         used: u64,
315         available: u64,
316         log: u64,
317     ) -> Result<()> {
318         self.as_ref()
319             .set_vring_addr(index, flags, descriptor, used, available, log)
320     }
321 
set_vring_base(&self, index: u32, base: u32) -> Result<()>322     fn set_vring_base(&self, index: u32, base: u32) -> Result<()> {
323         self.as_ref().set_vring_base(index, base)
324     }
325 
get_vring_base(&self, index: u32) -> Result<VhostUserVringState>326     fn get_vring_base(&self, index: u32) -> Result<VhostUserVringState> {
327         self.as_ref().get_vring_base(index)
328     }
329 
set_vring_kick(&self, index: u8, fd: Option<File>) -> Result<()>330     fn set_vring_kick(&self, index: u8, fd: Option<File>) -> Result<()> {
331         self.as_ref().set_vring_kick(index, fd)
332     }
333 
set_vring_call(&self, index: u8, fd: Option<File>) -> Result<()>334     fn set_vring_call(&self, index: u8, fd: Option<File>) -> Result<()> {
335         self.as_ref().set_vring_call(index, fd)
336     }
337 
set_vring_err(&self, index: u8, fd: Option<File>) -> Result<()>338     fn set_vring_err(&self, index: u8, fd: Option<File>) -> Result<()> {
339         self.as_ref().set_vring_err(index, fd)
340     }
341 
get_protocol_features(&self) -> Result<VhostUserProtocolFeatures>342     fn get_protocol_features(&self) -> Result<VhostUserProtocolFeatures> {
343         self.as_ref().get_protocol_features()
344     }
345 
set_protocol_features(&self, features: u64) -> Result<()>346     fn set_protocol_features(&self, features: u64) -> Result<()> {
347         self.as_ref().set_protocol_features(features)
348     }
349 
get_queue_num(&self) -> Result<u64>350     fn get_queue_num(&self) -> Result<u64> {
351         self.as_ref().get_queue_num()
352     }
353 
set_vring_enable(&self, index: u32, enable: bool) -> Result<()>354     fn set_vring_enable(&self, index: u32, enable: bool) -> Result<()> {
355         self.as_ref().set_vring_enable(index, enable)
356     }
357 
get_config(&self, offset: u32, size: u32, flags: VhostUserConfigFlags) -> Result<Vec<u8>>358     fn get_config(&self, offset: u32, size: u32, flags: VhostUserConfigFlags) -> Result<Vec<u8>> {
359         self.as_ref().get_config(offset, size, flags)
360     }
361 
set_config(&self, offset: u32, buf: &[u8], flags: VhostUserConfigFlags) -> Result<()>362     fn set_config(&self, offset: u32, buf: &[u8], flags: VhostUserConfigFlags) -> Result<()> {
363         self.as_ref().set_config(offset, buf, flags)
364     }
365 
set_slave_req_fd(&self, vu_req: Box<dyn Endpoint<SlaveReq>>)366     fn set_slave_req_fd(&self, vu_req: Box<dyn Endpoint<SlaveReq>>) {
367         self.as_ref().set_slave_req_fd(vu_req)
368     }
369 
get_inflight_fd(&self, inflight: &VhostUserInflight) -> Result<(VhostUserInflight, File)>370     fn get_inflight_fd(&self, inflight: &VhostUserInflight) -> Result<(VhostUserInflight, File)> {
371         self.as_ref().get_inflight_fd(inflight)
372     }
373 
set_inflight_fd(&self, inflight: &VhostUserInflight, file: File) -> Result<()>374     fn set_inflight_fd(&self, inflight: &VhostUserInflight, file: File) -> Result<()> {
375         self.as_ref().set_inflight_fd(inflight, file)
376     }
377 
get_max_mem_slots(&self) -> Result<u64>378     fn get_max_mem_slots(&self) -> Result<u64> {
379         self.as_ref().get_max_mem_slots()
380     }
381 
add_mem_region(&self, region: &VhostUserSingleMemoryRegion, fd: File) -> Result<()>382     fn add_mem_region(&self, region: &VhostUserSingleMemoryRegion, fd: File) -> Result<()> {
383         self.as_ref().add_mem_region(region, fd)
384     }
385 
remove_mem_region(&self, region: &VhostUserSingleMemoryRegion) -> Result<()>386     fn remove_mem_region(&self, region: &VhostUserSingleMemoryRegion) -> Result<()> {
387         self.as_ref().remove_mem_region(region)
388     }
389 
get_shared_memory_regions(&self) -> Result<Vec<VhostSharedMemoryRegion>>390     fn get_shared_memory_regions(&self) -> Result<Vec<VhostSharedMemoryRegion>> {
391         self.as_ref().get_shared_memory_regions()
392     }
393 }
394 
395 /// Abstracts |Endpoint| related operations for vhost-user slave implementations.
396 pub struct SlaveReqHelper<E: Endpoint<MasterReq>> {
397     /// Underlying endpoint for communication.
398     endpoint: E,
399 
400     /// Protocol used for the communication.
401     protocol: Protocol,
402 
403     /// Sending ack for messages without payload.
404     reply_ack_enabled: bool,
405 }
406 
407 impl<E: Endpoint<MasterReq>> SlaveReqHelper<E> {
408     /// Creates a new |SlaveReqHelper| instance with an |Endpoint| underneath it.
new(endpoint: E, protocol: Protocol) -> Self409     pub fn new(endpoint: E, protocol: Protocol) -> Self {
410         SlaveReqHelper {
411             endpoint,
412             protocol,
413             reply_ack_enabled: false,
414         }
415     }
416 
new_reply_header<T: Sized>( &self, req: &VhostUserMsgHeader<MasterReq>, payload_size: usize, ) -> Result<VhostUserMsgHeader<MasterReq>>417     fn new_reply_header<T: Sized>(
418         &self,
419         req: &VhostUserMsgHeader<MasterReq>,
420         payload_size: usize,
421     ) -> Result<VhostUserMsgHeader<MasterReq>> {
422         if mem::size_of::<T>() > MAX_MSG_SIZE
423             || payload_size > MAX_MSG_SIZE
424             || mem::size_of::<T>() + payload_size > MAX_MSG_SIZE
425         {
426             return Err(Error::InvalidParam);
427         }
428 
429         Ok(VhostUserMsgHeader::new(
430             req.get_code(),
431             VhostUserHeaderFlag::REPLY.bits(),
432             (mem::size_of::<T>() + payload_size) as u32,
433         ))
434     }
435 
436     /// Sends reply back to Vhost Master in response to a message.
send_ack_message( &mut self, req: &VhostUserMsgHeader<MasterReq>, success: bool, ) -> Result<()>437     pub fn send_ack_message(
438         &mut self,
439         req: &VhostUserMsgHeader<MasterReq>,
440         success: bool,
441     ) -> Result<()> {
442         if self.reply_ack_enabled && req.is_need_reply() {
443             let hdr = self.new_reply_header::<VhostUserU64>(req, 0)?;
444             let val = if success { 0 } else { 1 };
445             let msg = VhostUserU64::new(val);
446             self.endpoint.send_message(&hdr, &msg, None)?;
447         }
448         Ok(())
449     }
450 
send_reply_message<T: Sized + DataInit>( &mut self, req: &VhostUserMsgHeader<MasterReq>, msg: &T, ) -> Result<()>451     fn send_reply_message<T: Sized + DataInit>(
452         &mut self,
453         req: &VhostUserMsgHeader<MasterReq>,
454         msg: &T,
455     ) -> Result<()> {
456         let hdr = self.new_reply_header::<T>(req, 0)?;
457         self.endpoint.send_message(&hdr, msg, None)?;
458         Ok(())
459     }
460 
send_reply_with_payload<T: Sized + DataInit>( &mut self, req: &VhostUserMsgHeader<MasterReq>, msg: &T, payload: &[u8], ) -> Result<()>461     fn send_reply_with_payload<T: Sized + DataInit>(
462         &mut self,
463         req: &VhostUserMsgHeader<MasterReq>,
464         msg: &T,
465         payload: &[u8],
466     ) -> Result<()> {
467         let hdr = self.new_reply_header::<T>(req, payload.len())?;
468         self.endpoint
469             .send_message_with_payload(&hdr, msg, payload, None)?;
470         Ok(())
471     }
472 
473     /// Parses an incoming |SET_VRING_KICK| or |SET_VRING_CALL| message into a
474     /// Vring number and an fd.
handle_vring_fd_request( &mut self, buf: &[u8], files: Option<Vec<File>>, ) -> Result<(u8, Option<File>)>475     pub fn handle_vring_fd_request(
476         &mut self,
477         buf: &[u8],
478         files: Option<Vec<File>>,
479     ) -> Result<(u8, Option<File>)> {
480         if buf.len() > MAX_MSG_SIZE || buf.len() < mem::size_of::<VhostUserU64>() {
481             return Err(Error::InvalidMessage);
482         }
483         let msg = unsafe { std::ptr::read_unaligned(buf.as_ptr() as *const VhostUserU64) };
484         if !msg.is_valid() {
485             return Err(Error::InvalidMessage);
486         }
487 
488         // Virtio-vhost-user protocol doesn't send FDs.
489         if self.protocol == Protocol::Virtio {
490             return Ok((msg.value as u8, None));
491         }
492 
493         // Bits (0-7) of the payload contain the vring index. Bit 8 is the
494         // invalid FD flag (VHOST_USER_VRING_NOFD_MASK).
495         // This bit is set when there is no file descriptor
496         // in the ancillary data. This signals that polling will be used
497         // instead of waiting for the call.
498         // If Bit 8 is unset, the data must contain a file descriptor.
499         let has_fd = (msg.value & 0x100u64) == 0;
500 
501         let file = take_single_file(files);
502 
503         if has_fd && file.is_none() || !has_fd && file.is_some() {
504             return Err(Error::InvalidMessage);
505         }
506 
507         Ok((msg.value as u8, file))
508     }
509 }
510 
511 impl<E: Endpoint<MasterReq>> AsRef<E> for SlaveReqHelper<E> {
as_ref(&self) -> &E512     fn as_ref(&self) -> &E {
513         &self.endpoint
514     }
515 }
516 
517 impl<E: Endpoint<MasterReq>> AsMut<E> for SlaveReqHelper<E> {
as_mut(&mut self) -> &mut E518     fn as_mut(&mut self) -> &mut E {
519         &mut self.endpoint
520     }
521 }
522 
523 impl<E: Endpoint<MasterReq> + AsRawDescriptor> AsRawDescriptor for SlaveReqHelper<E> {
as_raw_descriptor(&self) -> RawDescriptor524     fn as_raw_descriptor(&self) -> RawDescriptor {
525         self.endpoint.as_raw_descriptor()
526     }
527 }
528 
529 /// Server to handle service requests from masters from the master communication channel.
530 ///
531 /// The [SlaveReqHandler] acts as a server on the slave side, to handle service requests from
532 /// masters on the master communication channel. It's actually a proxy invoking the registered
533 /// handler implementing [VhostUserSlaveReqHandler] to do the real work.
534 ///
535 /// The lifetime of the SlaveReqHandler object should be the same as the underline Unix Domain
536 /// Socket, so it gets simpler to recover from disconnect.
537 ///
538 /// [VhostUserSlaveReqHandler]: trait.VhostUserSlaveReqHandler.html
539 /// [SlaveReqHandler]: struct.SlaveReqHandler.html
540 pub struct SlaveReqHandler<S: VhostUserSlaveReqHandler, E: Endpoint<MasterReq>> {
541     slave_req_helper: SlaveReqHelper<E>,
542     // the vhost-user backend device object
543     backend: S,
544 
545     virtio_features: u64,
546     acked_virtio_features: u64,
547     protocol_features: VhostUserProtocolFeatures,
548     acked_protocol_features: u64,
549 }
550 
551 impl<S: VhostUserSlaveReqHandler> SlaveReqHandler<S, MasterReqEndpoint> {
552     /// Create a vhost-user slave endpoint from a connected socket.
from_stream(socket: SystemStream, backend: S) -> Self553     pub fn from_stream(socket: SystemStream, backend: S) -> Self {
554         Self::new(MasterReqEndpoint::from(socket), backend)
555     }
556 }
557 
558 impl<S: VhostUserSlaveReqHandler, E: Endpoint<MasterReq>> AsRef<S> for SlaveReqHandler<S, E> {
as_ref(&self) -> &S559     fn as_ref(&self) -> &S {
560         &self.backend
561     }
562 }
563 
564 impl<S: VhostUserSlaveReqHandler, E: Endpoint<MasterReq>> SlaveReqHandler<S, E> {
565     /// Create a vhost-user slave endpoint.
new(endpoint: E, backend: S) -> Self566     pub fn new(endpoint: E, backend: S) -> Self {
567         SlaveReqHandler {
568             slave_req_helper: SlaveReqHelper::new(endpoint, backend.protocol()),
569             backend,
570             virtio_features: 0,
571             acked_virtio_features: 0,
572             protocol_features: VhostUserProtocolFeatures::empty(),
573             acked_protocol_features: 0,
574         }
575     }
576 
577     /// Receives and validates a vhost-user message header and optional files.
578     ///
579     /// Since the length of vhost-user messages are different among message types, regular
580     /// vhost-user messages are sent via an underlying communication channel in stream mode.
581     /// (e.g. `SOCK_STREAM` in UNIX)
582     /// So, the logic of receiving and handling a message consists of the following steps:
583     ///
584     /// 1. Receives a message header and optional attached file.
585     /// 2. Validates the message header.
586     /// 3. Check if optional payloads is expected.
587     /// 4. Wait for the optional payloads.
588     /// 5. Receives optional payloads.
589     /// 6. Processes the message.
590     ///
591     /// This method [`SlaveReqHandler::recv_header()`] is in charge of the step (1) and (2),
592     /// [`SlaveReqHandler::needs_wait_for_payload()`] is (3), and
593     /// [`SlaveReqHandler::process_message()`] is (5) and (6).
594     /// We need to have the three method separately for multi-platform supports;
595     /// [`SlaveReqHandler::recv_header()`] and [`SlaveReqHandler::process_message()`] need to be
596     /// separated because the way of waiting for incoming messages differs between Unix and Windows
597     /// so it's the caller's responsibility to wait before [`SlaveReqHandler::process_message()`].
598     ///
599     /// Note that some vhost-user protocol variant such as VVU doesn't assume stream mode. In this
600     /// case, a message header and its body are sent together so the step (4) is skipped. We handle
601     /// this case in [`SlaveReqHandler::needs_wait_for_payload()`].
602     ///
603     /// The following pseudo code describes how a caller should process incoming vhost-user
604     /// messages:
605     /// ```ignore
606     /// loop {
607     ///   // block until a message header comes.
608     ///   // The actual code differs, depending on platforms.
609     ///   connection.wait_readable().unwrap();
610     ///
611     ///   // (1) and (2)
612     ///   let (hdr, files) = slave_req_handler.recv_header();
613     ///
614     ///   // (3)
615     ///   if slave_req_handler.needs_wait_for_payload(&hdr) {
616     ///     // (4) block until a payload comes if needed.
617     ///     connection.wait_readable().unwrap();
618     ///   }
619     ///
620     ///   // (5) and (6)
621     ///   slave_req_handler.process_message(&hdr, &files).unwrap();
622     /// }
623     /// ```
recv_header(&mut self) -> Result<(VhostUserMsgHeader<MasterReq>, Option<Vec<File>>)>624     pub fn recv_header(&mut self) -> Result<(VhostUserMsgHeader<MasterReq>, Option<Vec<File>>)> {
625         // The underlying communication channel is a Unix domain socket in
626         // stream mode, and recvmsg() is a little tricky here. To successfully
627         // receive attached file descriptors, we need to receive messages and
628         // corresponding attached file descriptors in this way:
629         // . recv messsage header and optional attached file
630         // . validate message header
631         // . recv optional message body and payload according size field in
632         //   message header
633         // . validate message body and optional payload
634         let (hdr, files) = match self.slave_req_helper.endpoint.recv_header() {
635             Ok((hdr, files)) => (hdr, files),
636             Err(Error::Disconnect) => {
637                 // If the client closed the connection before sending a header, this should be
638                 // handled as a legal exit.
639                 return Err(Error::ClientExit);
640             }
641             Err(e) => {
642                 return Err(e);
643             }
644         };
645 
646         self.check_attached_files(&hdr, &files)?;
647 
648         Ok((hdr, files))
649     }
650 
651     /// Returns whether the caller needs to wait for the incoming message before calling
652     /// [`SlaveReqHandler::process_message`].
653     ///
654     /// See [`SlaveReqHandler::recv_header`]'s doc comment for the usage.
needs_wait_for_payload(&self, hdr: &VhostUserMsgHeader<MasterReq>) -> bool655     pub fn needs_wait_for_payload(&self, hdr: &VhostUserMsgHeader<MasterReq>) -> bool {
656         // For the vhost-user protocols using stream mode, we need to wait until an additional
657         // payload is available if exists.
658         self.backend.protocol().is_stream_mode() && hdr.get_size() != 0
659     }
660 
661     /// Main entrance to request from the communication channel.
662     ///
663     /// Receive and handle one incoming request message from the vmm.
664     /// See [`SlaveReqHandler::recv_header`]'s doc comment for the usage.
665     ///
666     /// # Return:
667     /// * - `Ok(())`: one request was successfully handled.
668     /// * - `Err(ClientExit)`: the vmm closed the connection properly. This isn't an actual failure.
669     /// * - `Err(Disconnect)`: the connection was closed unexpectedly.
670     /// * - `Err(InvalidMessage)`: the vmm sent a illegal message.
671     /// * - other errors: failed to handle a request.
process_message( &mut self, hdr: VhostUserMsgHeader<MasterReq>, files: Option<Vec<File>>, ) -> Result<()>672     pub fn process_message(
673         &mut self,
674         hdr: VhostUserMsgHeader<MasterReq>,
675         files: Option<Vec<File>>,
676     ) -> Result<()> {
677         let buf = match hdr.get_size() {
678             0 => vec![0u8; 0],
679             len => {
680                 let rbuf = self.slave_req_helper.endpoint.recv_data(len as usize)?;
681                 if rbuf.len() != len as usize {
682                     return Err(Error::InvalidMessage);
683                 }
684                 rbuf
685             }
686         };
687         let size = buf.len();
688 
689         match hdr.get_code() {
690             MasterReq::SET_OWNER => {
691                 self.check_request_size(&hdr, size, 0)?;
692                 let res = self.backend.set_owner();
693                 self.slave_req_helper.send_ack_message(&hdr, res.is_ok())?;
694                 res?;
695             }
696             MasterReq::RESET_OWNER => {
697                 self.check_request_size(&hdr, size, 0)?;
698                 let res = self.backend.reset_owner();
699                 self.slave_req_helper.send_ack_message(&hdr, res.is_ok())?;
700                 res?;
701             }
702             MasterReq::GET_FEATURES => {
703                 self.check_request_size(&hdr, size, 0)?;
704                 let features = self.backend.get_features()?;
705                 let msg = VhostUserU64::new(features);
706                 self.slave_req_helper.send_reply_message(&hdr, &msg)?;
707                 self.virtio_features = features;
708                 self.update_reply_ack_flag();
709             }
710             MasterReq::SET_FEATURES => {
711                 let msg = self.extract_request_body::<VhostUserU64>(&hdr, size, &buf)?;
712                 let res = self.backend.set_features(msg.value);
713                 self.acked_virtio_features = msg.value;
714                 self.update_reply_ack_flag();
715                 self.slave_req_helper.send_ack_message(&hdr, res.is_ok())?;
716                 res?;
717             }
718             MasterReq::SET_MEM_TABLE => {
719                 let res = self.set_mem_table(&hdr, size, &buf, files);
720                 self.slave_req_helper.send_ack_message(&hdr, res.is_ok())?;
721                 res?;
722             }
723             MasterReq::SET_VRING_NUM => {
724                 let msg = self.extract_request_body::<VhostUserVringState>(&hdr, size, &buf)?;
725                 let res = self.backend.set_vring_num(msg.index, msg.num);
726                 self.slave_req_helper.send_ack_message(&hdr, res.is_ok())?;
727                 res?;
728             }
729             MasterReq::SET_VRING_ADDR => {
730                 let msg = self.extract_request_body::<VhostUserVringAddr>(&hdr, size, &buf)?;
731                 let flags = match VhostUserVringAddrFlags::from_bits(msg.flags) {
732                     Some(val) => val,
733                     None => return Err(Error::InvalidMessage),
734                 };
735                 let res = self.backend.set_vring_addr(
736                     msg.index,
737                     flags,
738                     msg.descriptor,
739                     msg.used,
740                     msg.available,
741                     msg.log,
742                 );
743                 self.slave_req_helper.send_ack_message(&hdr, res.is_ok())?;
744                 res?;
745             }
746             MasterReq::SET_VRING_BASE => {
747                 let msg = self.extract_request_body::<VhostUserVringState>(&hdr, size, &buf)?;
748                 let res = self.backend.set_vring_base(msg.index, msg.num);
749                 self.slave_req_helper.send_ack_message(&hdr, res.is_ok())?;
750                 res?;
751             }
752             MasterReq::GET_VRING_BASE => {
753                 let msg = self.extract_request_body::<VhostUserVringState>(&hdr, size, &buf)?;
754                 let reply = self.backend.get_vring_base(msg.index)?;
755                 self.slave_req_helper.send_reply_message(&hdr, &reply)?;
756             }
757             MasterReq::SET_VRING_CALL => {
758                 self.check_request_size(&hdr, size, mem::size_of::<VhostUserU64>())?;
759                 let (index, file) = self.handle_vring_fd_request(&buf, files)?;
760                 let res = self.backend.set_vring_call(index, file);
761                 self.slave_req_helper.send_ack_message(&hdr, res.is_ok())?;
762                 res?;
763             }
764             MasterReq::SET_VRING_KICK => {
765                 self.check_request_size(&hdr, size, mem::size_of::<VhostUserU64>())?;
766                 let (index, file) = self.handle_vring_fd_request(&buf, files)?;
767                 let res = self.backend.set_vring_kick(index, file);
768                 self.slave_req_helper.send_ack_message(&hdr, res.is_ok())?;
769                 res?;
770             }
771             MasterReq::SET_VRING_ERR => {
772                 self.check_request_size(&hdr, size, mem::size_of::<VhostUserU64>())?;
773                 let (index, file) = self.handle_vring_fd_request(&buf, files)?;
774                 let res = self.backend.set_vring_err(index, file);
775                 self.slave_req_helper.send_ack_message(&hdr, res.is_ok())?;
776                 res?;
777             }
778             MasterReq::GET_PROTOCOL_FEATURES => {
779                 self.check_request_size(&hdr, size, 0)?;
780                 let features = self.backend.get_protocol_features()?;
781                 let msg = VhostUserU64::new(features.bits());
782                 self.slave_req_helper.send_reply_message(&hdr, &msg)?;
783                 self.protocol_features = features;
784                 self.update_reply_ack_flag();
785             }
786             MasterReq::SET_PROTOCOL_FEATURES => {
787                 let msg = self.extract_request_body::<VhostUserU64>(&hdr, size, &buf)?;
788                 let res = self.backend.set_protocol_features(msg.value);
789                 self.acked_protocol_features = msg.value;
790                 self.update_reply_ack_flag();
791                 self.slave_req_helper.send_ack_message(&hdr, res.is_ok())?;
792                 res?;
793             }
794             MasterReq::GET_QUEUE_NUM => {
795                 if self.acked_protocol_features & VhostUserProtocolFeatures::MQ.bits() == 0 {
796                     return Err(Error::InvalidOperation);
797                 }
798                 self.check_request_size(&hdr, size, 0)?;
799                 let num = self.backend.get_queue_num()?;
800                 let msg = VhostUserU64::new(num);
801                 self.slave_req_helper.send_reply_message(&hdr, &msg)?;
802             }
803             MasterReq::SET_VRING_ENABLE => {
804                 let msg = self.extract_request_body::<VhostUserVringState>(&hdr, size, &buf)?;
805                 if self.acked_virtio_features & VhostUserVirtioFeatures::PROTOCOL_FEATURES.bits()
806                     == 0
807                 {
808                     return Err(Error::InvalidOperation);
809                 }
810                 let enable = match msg.num {
811                     1 => true,
812                     0 => false,
813                     _ => return Err(Error::InvalidParam),
814                 };
815 
816                 let res = self.backend.set_vring_enable(msg.index, enable);
817                 self.slave_req_helper.send_ack_message(&hdr, res.is_ok())?;
818                 res?;
819             }
820             MasterReq::GET_CONFIG => {
821                 if self.acked_protocol_features & VhostUserProtocolFeatures::CONFIG.bits() == 0 {
822                     return Err(Error::InvalidOperation);
823                 }
824                 self.check_request_size(&hdr, size, hdr.get_size() as usize)?;
825                 self.get_config(&hdr, &buf)?;
826             }
827             MasterReq::SET_CONFIG => {
828                 if self.acked_protocol_features & VhostUserProtocolFeatures::CONFIG.bits() == 0 {
829                     return Err(Error::InvalidOperation);
830                 }
831                 self.check_request_size(&hdr, size, hdr.get_size() as usize)?;
832                 let res = self.set_config(size, &buf);
833                 self.slave_req_helper.send_ack_message(&hdr, res.is_ok())?;
834                 res?;
835             }
836             MasterReq::SET_SLAVE_REQ_FD => {
837                 if self.acked_protocol_features & VhostUserProtocolFeatures::SLAVE_REQ.bits() == 0 {
838                     return Err(Error::InvalidOperation);
839                 }
840                 self.check_request_size(&hdr, size, hdr.get_size() as usize)?;
841                 let res = self.set_slave_req_fd(files);
842                 self.slave_req_helper.send_ack_message(&hdr, res.is_ok())?;
843                 res?;
844             }
845             MasterReq::GET_INFLIGHT_FD => {
846                 if self.acked_protocol_features & VhostUserProtocolFeatures::INFLIGHT_SHMFD.bits()
847                     == 0
848                 {
849                     return Err(Error::InvalidOperation);
850                 }
851 
852                 let msg = self.extract_request_body::<VhostUserInflight>(&hdr, size, &buf)?;
853                 let (inflight, file) = self.backend.get_inflight_fd(&msg)?;
854                 let reply_hdr = self
855                     .slave_req_helper
856                     .new_reply_header::<VhostUserInflight>(&hdr, 0)?;
857                 self.slave_req_helper.endpoint.send_message(
858                     &reply_hdr,
859                     &inflight,
860                     Some(&[file.as_raw_descriptor()]),
861                 )?;
862             }
863             MasterReq::SET_INFLIGHT_FD => {
864                 if self.acked_protocol_features & VhostUserProtocolFeatures::INFLIGHT_SHMFD.bits()
865                     == 0
866                 {
867                     return Err(Error::InvalidOperation);
868                 }
869                 let file = take_single_file(files).ok_or(Error::IncorrectFds)?;
870                 let msg = self.extract_request_body::<VhostUserInflight>(&hdr, size, &buf)?;
871                 let res = self.backend.set_inflight_fd(&msg, file);
872                 self.slave_req_helper.send_ack_message(&hdr, res.is_ok())?;
873                 res?;
874             }
875             MasterReq::GET_MAX_MEM_SLOTS => {
876                 if self.acked_protocol_features
877                     & VhostUserProtocolFeatures::CONFIGURE_MEM_SLOTS.bits()
878                     == 0
879                 {
880                     return Err(Error::InvalidOperation);
881                 }
882                 self.check_request_size(&hdr, size, 0)?;
883                 let num = self.backend.get_max_mem_slots()?;
884                 let msg = VhostUserU64::new(num);
885                 self.slave_req_helper.send_reply_message(&hdr, &msg)?;
886             }
887             MasterReq::ADD_MEM_REG => {
888                 if self.acked_protocol_features
889                     & VhostUserProtocolFeatures::CONFIGURE_MEM_SLOTS.bits()
890                     == 0
891                 {
892                     return Err(Error::InvalidOperation);
893                 }
894                 let mut files = files.ok_or(Error::InvalidParam)?;
895                 if files.len() != 1 {
896                     return Err(Error::InvalidParam);
897                 }
898                 let msg =
899                     self.extract_request_body::<VhostUserSingleMemoryRegion>(&hdr, size, &buf)?;
900                 let res = self.backend.add_mem_region(&msg, files.swap_remove(0));
901                 self.slave_req_helper.send_ack_message(&hdr, res.is_ok())?;
902                 res?;
903             }
904             MasterReq::REM_MEM_REG => {
905                 if self.acked_protocol_features
906                     & VhostUserProtocolFeatures::CONFIGURE_MEM_SLOTS.bits()
907                     == 0
908                 {
909                     return Err(Error::InvalidOperation);
910                 }
911 
912                 let msg =
913                     self.extract_request_body::<VhostUserSingleMemoryRegion>(&hdr, size, &buf)?;
914                 let res = self.backend.remove_mem_region(&msg);
915                 self.slave_req_helper.send_ack_message(&hdr, res.is_ok())?;
916                 res?;
917             }
918             MasterReq::GET_SHARED_MEMORY_REGIONS => {
919                 let regions = self.backend.get_shared_memory_regions()?;
920                 let mut buf = Vec::new();
921                 let msg = VhostUserU64::new(regions.len() as u64);
922                 for r in regions {
923                     buf.extend_from_slice(r.as_bytes())
924                 }
925                 self.slave_req_helper
926                     .send_reply_with_payload(&hdr, &msg, buf.as_slice())?;
927             }
928             _ => {
929                 return Err(Error::InvalidMessage);
930             }
931         }
932         Ok(())
933     }
934 
set_mem_table( &mut self, hdr: &VhostUserMsgHeader<MasterReq>, size: usize, buf: &[u8], files: Option<Vec<File>>, ) -> Result<()>935     fn set_mem_table(
936         &mut self,
937         hdr: &VhostUserMsgHeader<MasterReq>,
938         size: usize,
939         buf: &[u8],
940         files: Option<Vec<File>>,
941     ) -> Result<()> {
942         self.check_request_size(hdr, size, hdr.get_size() as usize)?;
943 
944         // check message size is consistent
945         let hdrsize = mem::size_of::<VhostUserMemory>();
946         if size < hdrsize {
947             return Err(Error::InvalidMessage);
948         }
949         let msg = unsafe { &*(buf.as_ptr() as *const VhostUserMemory) };
950         if !msg.is_valid() {
951             return Err(Error::InvalidMessage);
952         }
953         if size != hdrsize + msg.num_regions as usize * mem::size_of::<VhostUserMemoryRegion>() {
954             return Err(Error::InvalidMessage);
955         }
956 
957         let files = match self.slave_req_helper.protocol {
958             Protocol::Regular => {
959                 // validate number of fds matching number of memory regions
960                 let files = files.ok_or(Error::InvalidMessage)?;
961                 if files.len() != msg.num_regions as usize {
962                     return Err(Error::InvalidMessage);
963                 }
964                 files
965             }
966             Protocol::Virtio => vec![],
967         };
968 
969         // Validate memory regions
970         let regions = unsafe {
971             slice::from_raw_parts(
972                 buf.as_ptr().add(hdrsize) as *const VhostUserMemoryRegion,
973                 msg.num_regions as usize,
974             )
975         };
976         for region in regions.iter() {
977             if !region.is_valid() {
978                 return Err(Error::InvalidMessage);
979             }
980         }
981 
982         self.backend.set_mem_table(regions, files)
983     }
984 
get_config(&mut self, hdr: &VhostUserMsgHeader<MasterReq>, buf: &[u8]) -> Result<()>985     fn get_config(&mut self, hdr: &VhostUserMsgHeader<MasterReq>, buf: &[u8]) -> Result<()> {
986         let payload_offset = mem::size_of::<VhostUserConfig>();
987         if buf.len() > MAX_MSG_SIZE || buf.len() < payload_offset {
988             return Err(Error::InvalidMessage);
989         }
990         let msg = unsafe { std::ptr::read_unaligned(buf.as_ptr() as *const VhostUserConfig) };
991         if !msg.is_valid() {
992             return Err(Error::InvalidMessage);
993         }
994         if buf.len() - payload_offset != msg.size as usize {
995             return Err(Error::InvalidMessage);
996         }
997         let flags = match VhostUserConfigFlags::from_bits(msg.flags) {
998             Some(val) => val,
999             None => return Err(Error::InvalidMessage),
1000         };
1001         let res = self.backend.get_config(msg.offset, msg.size, flags);
1002 
1003         // vhost-user slave's payload size MUST match master's request
1004         // on success, uses zero length of payload to indicate an error
1005         // to vhost-user master.
1006         match res {
1007             Ok(ref buf) if buf.len() == msg.size as usize => {
1008                 let reply = VhostUserConfig::new(msg.offset, buf.len() as u32, flags);
1009                 self.slave_req_helper
1010                     .send_reply_with_payload(hdr, &reply, buf.as_slice())?;
1011             }
1012             Ok(_) => {
1013                 let reply = VhostUserConfig::new(msg.offset, 0, flags);
1014                 self.slave_req_helper.send_reply_message(hdr, &reply)?;
1015             }
1016             Err(_) => {
1017                 let reply = VhostUserConfig::new(msg.offset, 0, flags);
1018                 self.slave_req_helper.send_reply_message(hdr, &reply)?;
1019             }
1020         }
1021         Ok(())
1022     }
1023 
set_config(&mut self, size: usize, buf: &[u8]) -> Result<()>1024     fn set_config(&mut self, size: usize, buf: &[u8]) -> Result<()> {
1025         if size > MAX_MSG_SIZE || size < mem::size_of::<VhostUserConfig>() {
1026             return Err(Error::InvalidMessage);
1027         }
1028         let msg = unsafe { std::ptr::read_unaligned(buf.as_ptr() as *const VhostUserConfig) };
1029         if !msg.is_valid() {
1030             return Err(Error::InvalidMessage);
1031         }
1032         if size - mem::size_of::<VhostUserConfig>() != msg.size as usize {
1033             return Err(Error::InvalidMessage);
1034         }
1035         let flags: VhostUserConfigFlags = match VhostUserConfigFlags::from_bits(msg.flags) {
1036             Some(val) => val,
1037             None => return Err(Error::InvalidMessage),
1038         };
1039 
1040         self.backend.set_config(msg.offset, buf, flags)
1041     }
1042 
set_slave_req_fd(&mut self, files: Option<Vec<File>>) -> Result<()>1043     fn set_slave_req_fd(&mut self, files: Option<Vec<File>>) -> Result<()> {
1044         let ep = self
1045             .slave_req_helper
1046             .endpoint
1047             .create_slave_request_endpoint(files)?;
1048         self.backend.set_slave_req_fd(ep);
1049         Ok(())
1050     }
1051 
handle_vring_fd_request( &mut self, buf: &[u8], files: Option<Vec<File>>, ) -> Result<(u8, Option<File>)>1052     fn handle_vring_fd_request(
1053         &mut self,
1054         buf: &[u8],
1055         files: Option<Vec<File>>,
1056     ) -> Result<(u8, Option<File>)> {
1057         self.slave_req_helper.handle_vring_fd_request(buf, files)
1058     }
1059 
check_request_size( &self, hdr: &VhostUserMsgHeader<MasterReq>, size: usize, expected: usize, ) -> Result<()>1060     fn check_request_size(
1061         &self,
1062         hdr: &VhostUserMsgHeader<MasterReq>,
1063         size: usize,
1064         expected: usize,
1065     ) -> Result<()> {
1066         if hdr.get_size() as usize != expected
1067             || hdr.is_reply()
1068             || hdr.get_version() != 0x1
1069             || size != expected
1070         {
1071             return Err(Error::InvalidMessage);
1072         }
1073         Ok(())
1074     }
1075 
check_attached_files( &self, hdr: &VhostUserMsgHeader<MasterReq>, files: &Option<Vec<File>>, ) -> Result<()>1076     fn check_attached_files(
1077         &self,
1078         hdr: &VhostUserMsgHeader<MasterReq>,
1079         files: &Option<Vec<File>>,
1080     ) -> Result<()> {
1081         match hdr.get_code() {
1082             MasterReq::SET_MEM_TABLE
1083             | MasterReq::SET_VRING_CALL
1084             | MasterReq::SET_VRING_KICK
1085             | MasterReq::SET_VRING_ERR
1086             | MasterReq::SET_LOG_BASE
1087             | MasterReq::SET_LOG_FD
1088             | MasterReq::SET_SLAVE_REQ_FD
1089             | MasterReq::SET_INFLIGHT_FD
1090             | MasterReq::ADD_MEM_REG => Ok(()),
1091             _ if files.is_some() => Err(Error::InvalidMessage),
1092             _ => Ok(()),
1093         }
1094     }
1095 
extract_request_body<T: Sized + DataInit + VhostUserMsgValidator>( &self, hdr: &VhostUserMsgHeader<MasterReq>, size: usize, buf: &[u8], ) -> Result<T>1096     fn extract_request_body<T: Sized + DataInit + VhostUserMsgValidator>(
1097         &self,
1098         hdr: &VhostUserMsgHeader<MasterReq>,
1099         size: usize,
1100         buf: &[u8],
1101     ) -> Result<T> {
1102         self.check_request_size(hdr, size, mem::size_of::<T>())?;
1103         let msg = unsafe { std::ptr::read_unaligned(buf.as_ptr() as *const T) };
1104         if !msg.is_valid() {
1105             return Err(Error::InvalidMessage);
1106         }
1107         Ok(msg)
1108     }
1109 
update_reply_ack_flag(&mut self)1110     fn update_reply_ack_flag(&mut self) {
1111         let vflag = VhostUserVirtioFeatures::PROTOCOL_FEATURES.bits();
1112         let pflag = VhostUserProtocolFeatures::REPLY_ACK;
1113         if (self.virtio_features & vflag) != 0
1114             && self.protocol_features.contains(pflag)
1115             && (self.acked_protocol_features & pflag.bits()) != 0
1116         {
1117             self.slave_req_helper.reply_ack_enabled = true;
1118         } else {
1119             self.slave_req_helper.reply_ack_enabled = false;
1120         }
1121     }
1122 }
1123 
1124 impl<S: VhostUserSlaveReqHandler, E: AsRawDescriptor + Endpoint<MasterReq>> AsRawDescriptor
1125     for SlaveReqHandler<S, E>
1126 {
as_raw_descriptor(&self) -> RawDescriptor1127     fn as_raw_descriptor(&self) -> RawDescriptor {
1128         // TODO(b/221882601): figure out if this used for polling.
1129         self.slave_req_helper.endpoint.as_raw_descriptor()
1130     }
1131 }
1132 
#[cfg(test)]
mod tests {
    use base::INVALID_DESCRIPTOR;

    use super::*;
    use crate::dummy_slave::DummySlaveReqHandler;
    use crate::MasterReqEndpoint;
    use crate::SystemStream;

    /// A handler built over one end of a stream pair exposes a valid raw descriptor.
    #[test]
    fn test_slave_req_handler_new() {
        // Keep the peer end alive for the duration of the test.
        let (local, _peer) = SystemStream::pair().unwrap();
        let handler = SlaveReqHandler::new(
            MasterReqEndpoint::from(local),
            Mutex::new(DummySlaveReqHandler::new()),
        );

        assert_ne!(handler.as_raw_descriptor(), INVALID_DESCRIPTOR);
    }
}
1152