• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 // Copyright (C) 2019 Alibaba Cloud Computing. All rights reserved.
2 // SPDX-License-Identifier: Apache-2.0
3 
4 use base::{AsRawDescriptor, RawDescriptor};
5 use std::fs::File;
6 use std::mem;
7 use std::slice;
8 use std::sync::{Arc, Mutex};
9 
10 use data_model::DataInit;
11 
12 use super::connection::{Endpoint, EndpointExt};
13 use super::message::*;
14 use super::{take_single_file, Error, Result};
15 use crate::{MasterReqEndpoint, SystemStream};
16 
/// Vhost-user protocol variants used for the communication.
///
/// This is a plain fieldless enum, so it is `Copy`: callers can store and compare it by value
/// without re-querying the handler.
#[derive(Clone, Copy, PartialEq, Eq, Debug)]
pub enum Protocol {
    /// Use the regular vhost-user protocol.
    Regular,
    /// Use the virtio-vhost-user protocol, which is proxied through virtqueues.
    /// The protocol is mostly the same as the vhost-user protocol, but no file transfer is
    /// allowed.
    Virtio,
}
26 
27 /// Services provided to the master by the slave with interior mutability.
28 ///
29 /// The [VhostUserSlaveReqHandler] trait defines the services provided to the master by the slave.
30 /// And the [VhostUserSlaveReqHandlerMut] trait is a helper mirroring [VhostUserSlaveReqHandler],
31 /// but without interior mutability.
32 /// The vhost-user specification defines a master communication channel, by which masters could
33 /// request services from slaves. The [VhostUserSlaveReqHandler] trait defines services provided by
34 /// slaves, and it's used both on the master side and slave side.
35 ///
36 /// - on the master side, a stub forwarder implementing [VhostUserSlaveReqHandler] will proxy
37 ///   service requests to slaves.
38 /// - on the slave side, the [SlaveReqHandler] will forward service requests to a handler
39 ///   implementing [VhostUserSlaveReqHandler].
40 ///
41 /// The [VhostUserSlaveReqHandler] trait is design with interior mutability to improve performance
42 /// for multi-threading.
43 ///
44 /// [VhostUserSlaveReqHandler]: trait.VhostUserSlaveReqHandler.html
45 /// [VhostUserSlaveReqHandlerMut]: trait.VhostUserSlaveReqHandlerMut.html
46 /// [SlaveReqHandler]: struct.SlaveReqHandler.html
47 #[allow(missing_docs)]
48 pub trait VhostUserSlaveReqHandler {
49     /// Returns the type of vhost-user protocol that the handler support.
protocol(&self) -> Protocol50     fn protocol(&self) -> Protocol;
51 
set_owner(&self) -> Result<()>52     fn set_owner(&self) -> Result<()>;
reset_owner(&self) -> Result<()>53     fn reset_owner(&self) -> Result<()>;
get_features(&self) -> Result<u64>54     fn get_features(&self) -> Result<u64>;
set_features(&self, features: u64) -> Result<()>55     fn set_features(&self, features: u64) -> Result<()>;
set_mem_table(&self, ctx: &[VhostUserMemoryRegion], files: Vec<File>) -> Result<()>56     fn set_mem_table(&self, ctx: &[VhostUserMemoryRegion], files: Vec<File>) -> Result<()>;
set_vring_num(&self, index: u32, num: u32) -> Result<()>57     fn set_vring_num(&self, index: u32, num: u32) -> Result<()>;
set_vring_addr( &self, index: u32, flags: VhostUserVringAddrFlags, descriptor: u64, used: u64, available: u64, log: u64, ) -> Result<()>58     fn set_vring_addr(
59         &self,
60         index: u32,
61         flags: VhostUserVringAddrFlags,
62         descriptor: u64,
63         used: u64,
64         available: u64,
65         log: u64,
66     ) -> Result<()>;
set_vring_base(&self, index: u32, base: u32) -> Result<()>67     fn set_vring_base(&self, index: u32, base: u32) -> Result<()>;
get_vring_base(&self, index: u32) -> Result<VhostUserVringState>68     fn get_vring_base(&self, index: u32) -> Result<VhostUserVringState>;
set_vring_kick(&self, index: u8, fd: Option<File>) -> Result<()>69     fn set_vring_kick(&self, index: u8, fd: Option<File>) -> Result<()>;
set_vring_call(&self, index: u8, fd: Option<File>) -> Result<()>70     fn set_vring_call(&self, index: u8, fd: Option<File>) -> Result<()>;
set_vring_err(&self, index: u8, fd: Option<File>) -> Result<()>71     fn set_vring_err(&self, index: u8, fd: Option<File>) -> Result<()>;
72 
get_protocol_features(&self) -> Result<VhostUserProtocolFeatures>73     fn get_protocol_features(&self) -> Result<VhostUserProtocolFeatures>;
set_protocol_features(&self, features: u64) -> Result<()>74     fn set_protocol_features(&self, features: u64) -> Result<()>;
get_queue_num(&self) -> Result<u64>75     fn get_queue_num(&self) -> Result<u64>;
set_vring_enable(&self, index: u32, enable: bool) -> Result<()>76     fn set_vring_enable(&self, index: u32, enable: bool) -> Result<()>;
get_config(&self, offset: u32, size: u32, flags: VhostUserConfigFlags) -> Result<Vec<u8>>77     fn get_config(&self, offset: u32, size: u32, flags: VhostUserConfigFlags) -> Result<Vec<u8>>;
set_config(&self, offset: u32, buf: &[u8], flags: VhostUserConfigFlags) -> Result<()>78     fn set_config(&self, offset: u32, buf: &[u8], flags: VhostUserConfigFlags) -> Result<()>;
set_slave_req_fd(&self, _vu_req: File)79     fn set_slave_req_fd(&self, _vu_req: File) {}
get_inflight_fd(&self, inflight: &VhostUserInflight) -> Result<(VhostUserInflight, File)>80     fn get_inflight_fd(&self, inflight: &VhostUserInflight) -> Result<(VhostUserInflight, File)>;
set_inflight_fd(&self, inflight: &VhostUserInflight, file: File) -> Result<()>81     fn set_inflight_fd(&self, inflight: &VhostUserInflight, file: File) -> Result<()>;
get_max_mem_slots(&self) -> Result<u64>82     fn get_max_mem_slots(&self) -> Result<u64>;
add_mem_region(&self, region: &VhostUserSingleMemoryRegion, fd: File) -> Result<()>83     fn add_mem_region(&self, region: &VhostUserSingleMemoryRegion, fd: File) -> Result<()>;
remove_mem_region(&self, region: &VhostUserSingleMemoryRegion) -> Result<()>84     fn remove_mem_region(&self, region: &VhostUserSingleMemoryRegion) -> Result<()>;
85 }
86 
87 /// Services provided to the master by the slave without interior mutability.
88 ///
89 /// This is a helper trait mirroring the [VhostUserSlaveReqHandler] trait.
90 #[allow(missing_docs)]
91 pub trait VhostUserSlaveReqHandlerMut {
92     /// Returns the type of vhost-user protocol that the handler support.
protocol(&self) -> Protocol93     fn protocol(&self) -> Protocol;
94 
set_owner(&mut self) -> Result<()>95     fn set_owner(&mut self) -> Result<()>;
reset_owner(&mut self) -> Result<()>96     fn reset_owner(&mut self) -> Result<()>;
get_features(&mut self) -> Result<u64>97     fn get_features(&mut self) -> Result<u64>;
set_features(&mut self, features: u64) -> Result<()>98     fn set_features(&mut self, features: u64) -> Result<()>;
set_mem_table(&mut self, ctx: &[VhostUserMemoryRegion], files: Vec<File>) -> Result<()>99     fn set_mem_table(&mut self, ctx: &[VhostUserMemoryRegion], files: Vec<File>) -> Result<()>;
set_vring_num(&mut self, index: u32, num: u32) -> Result<()>100     fn set_vring_num(&mut self, index: u32, num: u32) -> Result<()>;
set_vring_addr( &mut self, index: u32, flags: VhostUserVringAddrFlags, descriptor: u64, used: u64, available: u64, log: u64, ) -> Result<()>101     fn set_vring_addr(
102         &mut self,
103         index: u32,
104         flags: VhostUserVringAddrFlags,
105         descriptor: u64,
106         used: u64,
107         available: u64,
108         log: u64,
109     ) -> Result<()>;
set_vring_base(&mut self, index: u32, base: u32) -> Result<()>110     fn set_vring_base(&mut self, index: u32, base: u32) -> Result<()>;
get_vring_base(&mut self, index: u32) -> Result<VhostUserVringState>111     fn get_vring_base(&mut self, index: u32) -> Result<VhostUserVringState>;
set_vring_kick(&mut self, index: u8, fd: Option<File>) -> Result<()>112     fn set_vring_kick(&mut self, index: u8, fd: Option<File>) -> Result<()>;
set_vring_call(&mut self, index: u8, fd: Option<File>) -> Result<()>113     fn set_vring_call(&mut self, index: u8, fd: Option<File>) -> Result<()>;
set_vring_err(&mut self, index: u8, fd: Option<File>) -> Result<()>114     fn set_vring_err(&mut self, index: u8, fd: Option<File>) -> Result<()>;
115 
get_protocol_features(&mut self) -> Result<VhostUserProtocolFeatures>116     fn get_protocol_features(&mut self) -> Result<VhostUserProtocolFeatures>;
set_protocol_features(&mut self, features: u64) -> Result<()>117     fn set_protocol_features(&mut self, features: u64) -> Result<()>;
get_queue_num(&mut self) -> Result<u64>118     fn get_queue_num(&mut self) -> Result<u64>;
set_vring_enable(&mut self, index: u32, enable: bool) -> Result<()>119     fn set_vring_enable(&mut self, index: u32, enable: bool) -> Result<()>;
get_config( &mut self, offset: u32, size: u32, flags: VhostUserConfigFlags, ) -> Result<Vec<u8>>120     fn get_config(
121         &mut self,
122         offset: u32,
123         size: u32,
124         flags: VhostUserConfigFlags,
125     ) -> Result<Vec<u8>>;
set_config(&mut self, offset: u32, buf: &[u8], flags: VhostUserConfigFlags) -> Result<()>126     fn set_config(&mut self, offset: u32, buf: &[u8], flags: VhostUserConfigFlags) -> Result<()>;
set_slave_req_fd(&mut self, _vu_req: File)127     fn set_slave_req_fd(&mut self, _vu_req: File) {}
get_inflight_fd( &mut self, inflight: &VhostUserInflight, ) -> Result<(VhostUserInflight, File)>128     fn get_inflight_fd(
129         &mut self,
130         inflight: &VhostUserInflight,
131     ) -> Result<(VhostUserInflight, File)>;
set_inflight_fd(&mut self, inflight: &VhostUserInflight, file: File) -> Result<()>132     fn set_inflight_fd(&mut self, inflight: &VhostUserInflight, file: File) -> Result<()>;
get_max_mem_slots(&mut self) -> Result<u64>133     fn get_max_mem_slots(&mut self) -> Result<u64>;
add_mem_region(&mut self, region: &VhostUserSingleMemoryRegion, fd: File) -> Result<()>134     fn add_mem_region(&mut self, region: &VhostUserSingleMemoryRegion, fd: File) -> Result<()>;
remove_mem_region(&mut self, region: &VhostUserSingleMemoryRegion) -> Result<()>135     fn remove_mem_region(&mut self, region: &VhostUserSingleMemoryRegion) -> Result<()>;
136 }
137 
138 impl<T: VhostUserSlaveReqHandlerMut> VhostUserSlaveReqHandler for Mutex<T> {
protocol(&self) -> Protocol139     fn protocol(&self) -> Protocol {
140         self.lock().unwrap().protocol()
141     }
142 
set_owner(&self) -> Result<()>143     fn set_owner(&self) -> Result<()> {
144         self.lock().unwrap().set_owner()
145     }
146 
reset_owner(&self) -> Result<()>147     fn reset_owner(&self) -> Result<()> {
148         self.lock().unwrap().reset_owner()
149     }
150 
get_features(&self) -> Result<u64>151     fn get_features(&self) -> Result<u64> {
152         self.lock().unwrap().get_features()
153     }
154 
set_features(&self, features: u64) -> Result<()>155     fn set_features(&self, features: u64) -> Result<()> {
156         self.lock().unwrap().set_features(features)
157     }
158 
set_mem_table(&self, ctx: &[VhostUserMemoryRegion], files: Vec<File>) -> Result<()>159     fn set_mem_table(&self, ctx: &[VhostUserMemoryRegion], files: Vec<File>) -> Result<()> {
160         self.lock().unwrap().set_mem_table(ctx, files)
161     }
162 
set_vring_num(&self, index: u32, num: u32) -> Result<()>163     fn set_vring_num(&self, index: u32, num: u32) -> Result<()> {
164         self.lock().unwrap().set_vring_num(index, num)
165     }
166 
set_vring_addr( &self, index: u32, flags: VhostUserVringAddrFlags, descriptor: u64, used: u64, available: u64, log: u64, ) -> Result<()>167     fn set_vring_addr(
168         &self,
169         index: u32,
170         flags: VhostUserVringAddrFlags,
171         descriptor: u64,
172         used: u64,
173         available: u64,
174         log: u64,
175     ) -> Result<()> {
176         self.lock()
177             .unwrap()
178             .set_vring_addr(index, flags, descriptor, used, available, log)
179     }
180 
set_vring_base(&self, index: u32, base: u32) -> Result<()>181     fn set_vring_base(&self, index: u32, base: u32) -> Result<()> {
182         self.lock().unwrap().set_vring_base(index, base)
183     }
184 
get_vring_base(&self, index: u32) -> Result<VhostUserVringState>185     fn get_vring_base(&self, index: u32) -> Result<VhostUserVringState> {
186         self.lock().unwrap().get_vring_base(index)
187     }
188 
set_vring_kick(&self, index: u8, fd: Option<File>) -> Result<()>189     fn set_vring_kick(&self, index: u8, fd: Option<File>) -> Result<()> {
190         self.lock().unwrap().set_vring_kick(index, fd)
191     }
192 
set_vring_call(&self, index: u8, fd: Option<File>) -> Result<()>193     fn set_vring_call(&self, index: u8, fd: Option<File>) -> Result<()> {
194         self.lock().unwrap().set_vring_call(index, fd)
195     }
196 
set_vring_err(&self, index: u8, fd: Option<File>) -> Result<()>197     fn set_vring_err(&self, index: u8, fd: Option<File>) -> Result<()> {
198         self.lock().unwrap().set_vring_err(index, fd)
199     }
200 
get_protocol_features(&self) -> Result<VhostUserProtocolFeatures>201     fn get_protocol_features(&self) -> Result<VhostUserProtocolFeatures> {
202         self.lock().unwrap().get_protocol_features()
203     }
204 
set_protocol_features(&self, features: u64) -> Result<()>205     fn set_protocol_features(&self, features: u64) -> Result<()> {
206         self.lock().unwrap().set_protocol_features(features)
207     }
208 
get_queue_num(&self) -> Result<u64>209     fn get_queue_num(&self) -> Result<u64> {
210         self.lock().unwrap().get_queue_num()
211     }
212 
set_vring_enable(&self, index: u32, enable: bool) -> Result<()>213     fn set_vring_enable(&self, index: u32, enable: bool) -> Result<()> {
214         self.lock().unwrap().set_vring_enable(index, enable)
215     }
216 
get_config(&self, offset: u32, size: u32, flags: VhostUserConfigFlags) -> Result<Vec<u8>>217     fn get_config(&self, offset: u32, size: u32, flags: VhostUserConfigFlags) -> Result<Vec<u8>> {
218         self.lock().unwrap().get_config(offset, size, flags)
219     }
220 
set_config(&self, offset: u32, buf: &[u8], flags: VhostUserConfigFlags) -> Result<()>221     fn set_config(&self, offset: u32, buf: &[u8], flags: VhostUserConfigFlags) -> Result<()> {
222         self.lock().unwrap().set_config(offset, buf, flags)
223     }
224 
set_slave_req_fd(&self, vu_req: File)225     fn set_slave_req_fd(&self, vu_req: File) {
226         self.lock().unwrap().set_slave_req_fd(vu_req)
227     }
228 
get_inflight_fd(&self, inflight: &VhostUserInflight) -> Result<(VhostUserInflight, File)>229     fn get_inflight_fd(&self, inflight: &VhostUserInflight) -> Result<(VhostUserInflight, File)> {
230         self.lock().unwrap().get_inflight_fd(inflight)
231     }
232 
set_inflight_fd(&self, inflight: &VhostUserInflight, file: File) -> Result<()>233     fn set_inflight_fd(&self, inflight: &VhostUserInflight, file: File) -> Result<()> {
234         self.lock().unwrap().set_inflight_fd(inflight, file)
235     }
236 
get_max_mem_slots(&self) -> Result<u64>237     fn get_max_mem_slots(&self) -> Result<u64> {
238         self.lock().unwrap().get_max_mem_slots()
239     }
240 
add_mem_region(&self, region: &VhostUserSingleMemoryRegion, fd: File) -> Result<()>241     fn add_mem_region(&self, region: &VhostUserSingleMemoryRegion, fd: File) -> Result<()> {
242         self.lock().unwrap().add_mem_region(region, fd)
243     }
244 
remove_mem_region(&self, region: &VhostUserSingleMemoryRegion) -> Result<()>245     fn remove_mem_region(&self, region: &VhostUserSingleMemoryRegion) -> Result<()> {
246         self.lock().unwrap().remove_mem_region(region)
247     }
248 }
249 
250 /// Abstracts |Endpoint| related operations for vhost-user slave implementations.
251 pub struct SlaveReqHelper<E: Endpoint<MasterReq>> {
252     /// Underlying endpoint for communication.
253     endpoint: E,
254 
255     /// Protocol used for the communication.
256     protocol: Protocol,
257 
258     /// Sending ack for messages without payload.
259     reply_ack_enabled: bool,
260 }
261 
262 impl<E: Endpoint<MasterReq>> SlaveReqHelper<E> {
263     /// Creates a new |SlaveReqHelper| instance with an |Endpoint| underneath it.
new(endpoint: E, protocol: Protocol) -> Self264     pub fn new(endpoint: E, protocol: Protocol) -> Self {
265         SlaveReqHelper {
266             endpoint,
267             protocol,
268             reply_ack_enabled: false,
269         }
270     }
271 
new_reply_header<T: Sized>( &self, req: &VhostUserMsgHeader<MasterReq>, payload_size: usize, ) -> Result<VhostUserMsgHeader<MasterReq>>272     fn new_reply_header<T: Sized>(
273         &self,
274         req: &VhostUserMsgHeader<MasterReq>,
275         payload_size: usize,
276     ) -> Result<VhostUserMsgHeader<MasterReq>> {
277         if mem::size_of::<T>() > MAX_MSG_SIZE
278             || payload_size > MAX_MSG_SIZE
279             || mem::size_of::<T>() + payload_size > MAX_MSG_SIZE
280         {
281             return Err(Error::InvalidParam);
282         }
283 
284         Ok(VhostUserMsgHeader::new(
285             req.get_code(),
286             VhostUserHeaderFlag::REPLY.bits(),
287             (mem::size_of::<T>() + payload_size) as u32,
288         ))
289     }
290 
291     /// Sends reply back to Vhost Master in response to a message.
send_ack_message( &mut self, req: &VhostUserMsgHeader<MasterReq>, success: bool, ) -> Result<()>292     pub fn send_ack_message(
293         &mut self,
294         req: &VhostUserMsgHeader<MasterReq>,
295         success: bool,
296     ) -> Result<()> {
297         if self.reply_ack_enabled && req.is_need_reply() {
298             let hdr = self.new_reply_header::<VhostUserU64>(req, 0)?;
299             let val = if success { 0 } else { 1 };
300             let msg = VhostUserU64::new(val);
301             self.endpoint.send_message(&hdr, &msg, None)?;
302         }
303         Ok(())
304     }
305 
send_reply_message<T: Sized + DataInit>( &mut self, req: &VhostUserMsgHeader<MasterReq>, msg: &T, ) -> Result<()>306     fn send_reply_message<T: Sized + DataInit>(
307         &mut self,
308         req: &VhostUserMsgHeader<MasterReq>,
309         msg: &T,
310     ) -> Result<()> {
311         let hdr = self.new_reply_header::<T>(req, 0)?;
312         self.endpoint.send_message(&hdr, msg, None)?;
313         Ok(())
314     }
315 
send_reply_with_payload<T: Sized + DataInit>( &mut self, req: &VhostUserMsgHeader<MasterReq>, msg: &T, payload: &[u8], ) -> Result<()>316     fn send_reply_with_payload<T: Sized + DataInit>(
317         &mut self,
318         req: &VhostUserMsgHeader<MasterReq>,
319         msg: &T,
320         payload: &[u8],
321     ) -> Result<()> {
322         let hdr = self.new_reply_header::<T>(req, payload.len())?;
323         self.endpoint
324             .send_message_with_payload(&hdr, msg, payload, None)?;
325         Ok(())
326     }
327 
328     /// Parses an incoming |SET_VRING_KICK| or |SET_VRING_CALL| message into a
329     /// Vring number and an fd.
handle_vring_fd_request( &mut self, buf: &[u8], files: Option<Vec<File>>, ) -> Result<(u8, Option<File>)>330     pub fn handle_vring_fd_request(
331         &mut self,
332         buf: &[u8],
333         files: Option<Vec<File>>,
334     ) -> Result<(u8, Option<File>)> {
335         if buf.len() > MAX_MSG_SIZE || buf.len() < mem::size_of::<VhostUserU64>() {
336             return Err(Error::InvalidMessage);
337         }
338         let msg = unsafe { std::ptr::read_unaligned(buf.as_ptr() as *const VhostUserU64) };
339         if !msg.is_valid() {
340             return Err(Error::InvalidMessage);
341         }
342 
343         // Virtio-vhost-user protocol doesn't send FDs.
344         if self.protocol == Protocol::Virtio {
345             return Ok((msg.value as u8, None));
346         }
347 
348         // Bits (0-7) of the payload contain the vring index. Bit 8 is the
349         // invalid FD flag (VHOST_USER_VRING_NOFD_MASK).
350         // This bit is set when there is no file descriptor
351         // in the ancillary data. This signals that polling will be used
352         // instead of waiting for the call.
353         // If Bit 8 is unset, the data must contain a file descriptor.
354         let has_fd = (msg.value & 0x100u64) == 0;
355 
356         let file = take_single_file(files);
357 
358         if has_fd && file.is_none() || !has_fd && file.is_some() {
359             return Err(Error::InvalidMessage);
360         }
361 
362         Ok((msg.value as u8, file))
363     }
364 }
365 
366 impl<E: Endpoint<MasterReq>> AsRef<E> for SlaveReqHelper<E> {
as_ref(&self) -> &E367     fn as_ref(&self) -> &E {
368         &self.endpoint
369     }
370 }
371 
372 impl<E: Endpoint<MasterReq>> AsMut<E> for SlaveReqHelper<E> {
as_mut(&mut self) -> &mut E373     fn as_mut(&mut self) -> &mut E {
374         &mut self.endpoint
375     }
376 }
377 
378 impl<E: Endpoint<MasterReq> + AsRawDescriptor> AsRawDescriptor for SlaveReqHelper<E> {
as_raw_descriptor(&self) -> RawDescriptor379     fn as_raw_descriptor(&self) -> RawDescriptor {
380         self.endpoint.as_raw_descriptor()
381     }
382 }
383 
384 /// Server to handle service requests from masters from the master communication channel.
385 ///
386 /// The [SlaveReqHandler] acts as a server on the slave side, to handle service requests from
387 /// masters on the master communication channel. It's actually a proxy invoking the registered
388 /// handler implementing [VhostUserSlaveReqHandler] to do the real work.
389 ///
390 /// The lifetime of the SlaveReqHandler object should be the same as the underline Unix Domain
391 /// Socket, so it gets simpler to recover from disconnect.
392 ///
393 /// [VhostUserSlaveReqHandler]: trait.VhostUserSlaveReqHandler.html
394 /// [SlaveReqHandler]: struct.SlaveReqHandler.html
395 pub struct SlaveReqHandler<S: VhostUserSlaveReqHandler, E: Endpoint<MasterReq>> {
396     slave_req_helper: SlaveReqHelper<E>,
397     // the vhost-user backend device object
398     backend: Arc<S>,
399 
400     virtio_features: u64,
401     acked_virtio_features: u64,
402     protocol_features: VhostUserProtocolFeatures,
403     acked_protocol_features: u64,
404 
405     // whether the endpoint has encountered any failure
406     error: Option<i32>,
407 }
408 
409 impl<S: VhostUserSlaveReqHandler> SlaveReqHandler<S, MasterReqEndpoint> {
410     /// Create a vhost-user slave endpoint from a connected socket.
from_stream(socket: SystemStream, backend: Arc<S>) -> Self411     pub fn from_stream(socket: SystemStream, backend: Arc<S>) -> Self {
412         Self::new(MasterReqEndpoint::from(socket), backend)
413     }
414 }
415 
416 impl<S: VhostUserSlaveReqHandler, E: Endpoint<MasterReq>> SlaveReqHandler<S, E> {
417     /// Create a vhost-user slave endpoint.
new(endpoint: E, backend: Arc<S>) -> Self418     pub(super) fn new(endpoint: E, backend: Arc<S>) -> Self {
419         SlaveReqHandler {
420             slave_req_helper: SlaveReqHelper::new(endpoint, backend.protocol()),
421             backend,
422             virtio_features: 0,
423             acked_virtio_features: 0,
424             protocol_features: VhostUserProtocolFeatures::empty(),
425             acked_protocol_features: 0,
426             error: None,
427         }
428     }
429 
430     /// Create a new vhost-user slave endpoint.
431     ///
432     /// # Arguments
433     /// * - `path` - path of Unix domain socket listener to connect to
434     /// * - `backend` - handler for requests from the master to the slave
connect(path: &str, backend: Arc<S>) -> Result<Self>435     pub fn connect(path: &str, backend: Arc<S>) -> Result<Self> {
436         Ok(Self::new(Endpoint::<MasterReq>::connect(path)?, backend))
437     }
438 
439     /// Mark endpoint as failed with specified error code.
set_failed(&mut self, error: i32)440     pub fn set_failed(&mut self, error: i32) {
441         self.error = Some(error);
442     }
443 
444     /// Main entrance to request from the communication channel.
445     ///
446     /// Receive and handle one incoming request message from the vmm. The caller needs to:
447     /// - serialize calls to this function
448     /// - decide what to do when error happens
449     /// - optional recover from failure
450     ///
451     /// # Return:
452     /// * - `Ok(())`: one request was successfully handled.
453     /// * - `Err(ClientExit)`: the vmm closed the connection properly. This isn't an actual failure.
454     /// * - `Err(Disconnect)`: the connection was closed unexpectedly.
455     /// * - `Err(InvalidMessage)`: the vmm sent a illegal message.
456     /// * - other errors: failed to handle a request.
handle_request(&mut self) -> Result<()>457     pub fn handle_request(&mut self) -> Result<()> {
458         // Return error if the endpoint is already in failed state.
459         self.check_state()?;
460 
461         // The underlying communication channel is a Unix domain socket in
462         // stream mode, and recvmsg() is a little tricky here. To successfully
463         // receive attached file descriptors, we need to receive messages and
464         // corresponding attached file descriptors in this way:
465         // . recv messsage header and optional attached file
466         // . validate message header
467         // . recv optional message body and payload according size field in
468         //   message header
469         // . validate message body and optional payload
470         let (hdr, files) = match self.slave_req_helper.endpoint.recv_header() {
471             Ok((hdr, files)) => (hdr, files),
472             Err(Error::Disconnect) => {
473                 // If the client closed the connection before sending a header, this should be
474                 // handled as a legal exit.
475                 return Err(Error::ClientExit);
476             }
477             Err(e) => {
478                 return Err(e);
479             }
480         };
481 
482         self.check_attached_files(&hdr, &files)?;
483 
484         let buf = match hdr.get_size() {
485             0 => vec![0u8; 0],
486             len => {
487                 let rbuf = self.slave_req_helper.endpoint.recv_data(len as usize)?;
488                 if rbuf.len() != len as usize {
489                     return Err(Error::InvalidMessage);
490                 }
491                 rbuf
492             }
493         };
494         let size = buf.len();
495 
496         match hdr.get_code() {
497             MasterReq::SET_OWNER => {
498                 self.check_request_size(&hdr, size, 0)?;
499                 let res = self.backend.set_owner();
500                 self.slave_req_helper.send_ack_message(&hdr, res.is_ok())?;
501                 res?;
502             }
503             MasterReq::RESET_OWNER => {
504                 self.check_request_size(&hdr, size, 0)?;
505                 let res = self.backend.reset_owner();
506                 self.slave_req_helper.send_ack_message(&hdr, res.is_ok())?;
507                 res?;
508             }
509             MasterReq::GET_FEATURES => {
510                 self.check_request_size(&hdr, size, 0)?;
511                 let features = self.backend.get_features()?;
512                 let msg = VhostUserU64::new(features);
513                 self.slave_req_helper.send_reply_message(&hdr, &msg)?;
514                 self.virtio_features = features;
515                 self.update_reply_ack_flag();
516             }
517             MasterReq::SET_FEATURES => {
518                 let msg = self.extract_request_body::<VhostUserU64>(&hdr, size, &buf)?;
519                 let res = self.backend.set_features(msg.value);
520                 self.acked_virtio_features = msg.value;
521                 self.update_reply_ack_flag();
522                 self.slave_req_helper.send_ack_message(&hdr, res.is_ok())?;
523                 res?;
524             }
525             MasterReq::SET_MEM_TABLE => {
526                 let res = self.set_mem_table(&hdr, size, &buf, files);
527                 self.slave_req_helper.send_ack_message(&hdr, res.is_ok())?;
528                 res?;
529             }
530             MasterReq::SET_VRING_NUM => {
531                 let msg = self.extract_request_body::<VhostUserVringState>(&hdr, size, &buf)?;
532                 let res = self.backend.set_vring_num(msg.index, msg.num);
533                 self.slave_req_helper.send_ack_message(&hdr, res.is_ok())?;
534                 res?;
535             }
536             MasterReq::SET_VRING_ADDR => {
537                 let msg = self.extract_request_body::<VhostUserVringAddr>(&hdr, size, &buf)?;
538                 let flags = match VhostUserVringAddrFlags::from_bits(msg.flags) {
539                     Some(val) => val,
540                     None => return Err(Error::InvalidMessage),
541                 };
542                 let res = self.backend.set_vring_addr(
543                     msg.index,
544                     flags,
545                     msg.descriptor,
546                     msg.used,
547                     msg.available,
548                     msg.log,
549                 );
550                 self.slave_req_helper.send_ack_message(&hdr, res.is_ok())?;
551                 res?;
552             }
553             MasterReq::SET_VRING_BASE => {
554                 let msg = self.extract_request_body::<VhostUserVringState>(&hdr, size, &buf)?;
555                 let res = self.backend.set_vring_base(msg.index, msg.num);
556                 self.slave_req_helper.send_ack_message(&hdr, res.is_ok())?;
557                 res?;
558             }
559             MasterReq::GET_VRING_BASE => {
560                 let msg = self.extract_request_body::<VhostUserVringState>(&hdr, size, &buf)?;
561                 let reply = self.backend.get_vring_base(msg.index)?;
562                 self.slave_req_helper.send_reply_message(&hdr, &reply)?;
563             }
564             MasterReq::SET_VRING_CALL => {
565                 self.check_request_size(&hdr, size, mem::size_of::<VhostUserU64>())?;
566                 let (index, file) = self.handle_vring_fd_request(&buf, files)?;
567                 let res = self.backend.set_vring_call(index, file);
568                 self.slave_req_helper.send_ack_message(&hdr, res.is_ok())?;
569                 res?;
570             }
571             MasterReq::SET_VRING_KICK => {
572                 self.check_request_size(&hdr, size, mem::size_of::<VhostUserU64>())?;
573                 let (index, file) = self.handle_vring_fd_request(&buf, files)?;
574                 let res = self.backend.set_vring_kick(index, file);
575                 self.slave_req_helper.send_ack_message(&hdr, res.is_ok())?;
576                 res?;
577             }
578             MasterReq::SET_VRING_ERR => {
579                 self.check_request_size(&hdr, size, mem::size_of::<VhostUserU64>())?;
580                 let (index, file) = self.handle_vring_fd_request(&buf, files)?;
581                 let res = self.backend.set_vring_err(index, file);
582                 self.slave_req_helper.send_ack_message(&hdr, res.is_ok())?;
583                 res?;
584             }
585             MasterReq::GET_PROTOCOL_FEATURES => {
586                 self.check_request_size(&hdr, size, 0)?;
587                 let features = self.backend.get_protocol_features()?;
588                 let msg = VhostUserU64::new(features.bits());
589                 self.slave_req_helper.send_reply_message(&hdr, &msg)?;
590                 self.protocol_features = features;
591                 self.update_reply_ack_flag();
592             }
593             MasterReq::SET_PROTOCOL_FEATURES => {
594                 let msg = self.extract_request_body::<VhostUserU64>(&hdr, size, &buf)?;
595                 let res = self.backend.set_protocol_features(msg.value);
596                 self.acked_protocol_features = msg.value;
597                 self.update_reply_ack_flag();
598                 self.slave_req_helper.send_ack_message(&hdr, res.is_ok())?;
599                 res?;
600             }
601             MasterReq::GET_QUEUE_NUM => {
602                 if self.acked_protocol_features & VhostUserProtocolFeatures::MQ.bits() == 0 {
603                     return Err(Error::InvalidOperation);
604                 }
605                 self.check_request_size(&hdr, size, 0)?;
606                 let num = self.backend.get_queue_num()?;
607                 let msg = VhostUserU64::new(num);
608                 self.slave_req_helper.send_reply_message(&hdr, &msg)?;
609             }
610             MasterReq::SET_VRING_ENABLE => {
611                 let msg = self.extract_request_body::<VhostUserVringState>(&hdr, size, &buf)?;
612                 if self.acked_virtio_features & VhostUserVirtioFeatures::PROTOCOL_FEATURES.bits()
613                     == 0
614                 {
615                     return Err(Error::InvalidOperation);
616                 }
617                 let enable = match msg.num {
618                     1 => true,
619                     0 => false,
620                     _ => return Err(Error::InvalidParam),
621                 };
622 
623                 let res = self.backend.set_vring_enable(msg.index, enable);
624                 self.slave_req_helper.send_ack_message(&hdr, res.is_ok())?;
625                 res?;
626             }
627             MasterReq::GET_CONFIG => {
628                 if self.acked_protocol_features & VhostUserProtocolFeatures::CONFIG.bits() == 0 {
629                     return Err(Error::InvalidOperation);
630                 }
631                 self.check_request_size(&hdr, size, hdr.get_size() as usize)?;
632                 self.get_config(&hdr, &buf)?;
633             }
634             MasterReq::SET_CONFIG => {
635                 if self.acked_protocol_features & VhostUserProtocolFeatures::CONFIG.bits() == 0 {
636                     return Err(Error::InvalidOperation);
637                 }
638                 self.check_request_size(&hdr, size, hdr.get_size() as usize)?;
639                 let res = self.set_config(size, &buf);
640                 self.slave_req_helper.send_ack_message(&hdr, res.is_ok())?;
641                 res?;
642             }
643             MasterReq::SET_SLAVE_REQ_FD => {
644                 if self.acked_protocol_features & VhostUserProtocolFeatures::SLAVE_REQ.bits() == 0 {
645                     return Err(Error::InvalidOperation);
646                 }
647                 self.check_request_size(&hdr, size, hdr.get_size() as usize)?;
648                 let res = self.set_slave_req_fd(files);
649                 self.slave_req_helper.send_ack_message(&hdr, res.is_ok())?;
650                 res?;
651             }
652             MasterReq::GET_INFLIGHT_FD => {
653                 if self.acked_protocol_features & VhostUserProtocolFeatures::INFLIGHT_SHMFD.bits()
654                     == 0
655                 {
656                     return Err(Error::InvalidOperation);
657                 }
658 
659                 let msg = self.extract_request_body::<VhostUserInflight>(&hdr, size, &buf)?;
660                 let (inflight, file) = self.backend.get_inflight_fd(&msg)?;
661                 let reply_hdr = self
662                     .slave_req_helper
663                     .new_reply_header::<VhostUserInflight>(&hdr, 0)?;
664                 self.slave_req_helper.endpoint.send_message(
665                     &reply_hdr,
666                     &inflight,
667                     Some(&[file.as_raw_descriptor()]),
668                 )?;
669             }
670             MasterReq::SET_INFLIGHT_FD => {
671                 if self.acked_protocol_features & VhostUserProtocolFeatures::INFLIGHT_SHMFD.bits()
672                     == 0
673                 {
674                     return Err(Error::InvalidOperation);
675                 }
676                 let file = take_single_file(files).ok_or(Error::IncorrectFds)?;
677                 let msg = self.extract_request_body::<VhostUserInflight>(&hdr, size, &buf)?;
678                 let res = self.backend.set_inflight_fd(&msg, file);
679                 self.slave_req_helper.send_ack_message(&hdr, res.is_ok())?;
680                 res?;
681             }
682             MasterReq::GET_MAX_MEM_SLOTS => {
683                 if self.acked_protocol_features
684                     & VhostUserProtocolFeatures::CONFIGURE_MEM_SLOTS.bits()
685                     == 0
686                 {
687                     return Err(Error::InvalidOperation);
688                 }
689                 self.check_request_size(&hdr, size, 0)?;
690                 let num = self.backend.get_max_mem_slots()?;
691                 let msg = VhostUserU64::new(num);
692                 self.slave_req_helper.send_reply_message(&hdr, &msg)?;
693             }
694             MasterReq::ADD_MEM_REG => {
695                 if self.acked_protocol_features
696                     & VhostUserProtocolFeatures::CONFIGURE_MEM_SLOTS.bits()
697                     == 0
698                 {
699                     return Err(Error::InvalidOperation);
700                 }
701                 let mut files = files.ok_or(Error::InvalidParam)?;
702                 if files.len() != 1 {
703                     return Err(Error::InvalidParam);
704                 }
705                 let msg =
706                     self.extract_request_body::<VhostUserSingleMemoryRegion>(&hdr, size, &buf)?;
707                 let res = self.backend.add_mem_region(&msg, files.swap_remove(0));
708                 self.slave_req_helper.send_ack_message(&hdr, res.is_ok())?;
709                 res?;
710             }
711             MasterReq::REM_MEM_REG => {
712                 if self.acked_protocol_features
713                     & VhostUserProtocolFeatures::CONFIGURE_MEM_SLOTS.bits()
714                     == 0
715                 {
716                     return Err(Error::InvalidOperation);
717                 }
718 
719                 let msg =
720                     self.extract_request_body::<VhostUserSingleMemoryRegion>(&hdr, size, &buf)?;
721                 let res = self.backend.remove_mem_region(&msg);
722                 self.slave_req_helper.send_ack_message(&hdr, res.is_ok())?;
723                 res?;
724             }
725             _ => {
726                 return Err(Error::InvalidMessage);
727             }
728         }
729         Ok(())
730     }
731 
set_mem_table( &mut self, hdr: &VhostUserMsgHeader<MasterReq>, size: usize, buf: &[u8], files: Option<Vec<File>>, ) -> Result<()>732     fn set_mem_table(
733         &mut self,
734         hdr: &VhostUserMsgHeader<MasterReq>,
735         size: usize,
736         buf: &[u8],
737         files: Option<Vec<File>>,
738     ) -> Result<()> {
739         self.check_request_size(hdr, size, hdr.get_size() as usize)?;
740 
741         // check message size is consistent
742         let hdrsize = mem::size_of::<VhostUserMemory>();
743         if size < hdrsize {
744             return Err(Error::InvalidMessage);
745         }
746         let msg = unsafe { &*(buf.as_ptr() as *const VhostUserMemory) };
747         if !msg.is_valid() {
748             return Err(Error::InvalidMessage);
749         }
750         if size != hdrsize + msg.num_regions as usize * mem::size_of::<VhostUserMemoryRegion>() {
751             return Err(Error::InvalidMessage);
752         }
753 
754         let files = match self.slave_req_helper.protocol {
755             Protocol::Regular => {
756                 // validate number of fds matching number of memory regions
757                 let files = files.ok_or(Error::InvalidMessage)?;
758                 if files.len() != msg.num_regions as usize {
759                     return Err(Error::InvalidMessage);
760                 }
761                 files
762             }
763             Protocol::Virtio => vec![],
764         };
765 
766         // Validate memory regions
767         let regions = unsafe {
768             slice::from_raw_parts(
769                 buf.as_ptr().add(hdrsize) as *const VhostUserMemoryRegion,
770                 msg.num_regions as usize,
771             )
772         };
773         for region in regions.iter() {
774             if !region.is_valid() {
775                 return Err(Error::InvalidMessage);
776             }
777         }
778 
779         self.backend.set_mem_table(regions, files)
780     }
781 
get_config(&mut self, hdr: &VhostUserMsgHeader<MasterReq>, buf: &[u8]) -> Result<()>782     fn get_config(&mut self, hdr: &VhostUserMsgHeader<MasterReq>, buf: &[u8]) -> Result<()> {
783         let payload_offset = mem::size_of::<VhostUserConfig>();
784         if buf.len() > MAX_MSG_SIZE || buf.len() < payload_offset {
785             return Err(Error::InvalidMessage);
786         }
787         let msg = unsafe { std::ptr::read_unaligned(buf.as_ptr() as *const VhostUserConfig) };
788         if !msg.is_valid() {
789             return Err(Error::InvalidMessage);
790         }
791         if buf.len() - payload_offset != msg.size as usize {
792             return Err(Error::InvalidMessage);
793         }
794         let flags = match VhostUserConfigFlags::from_bits(msg.flags) {
795             Some(val) => val,
796             None => return Err(Error::InvalidMessage),
797         };
798         let res = self.backend.get_config(msg.offset, msg.size, flags);
799 
800         // vhost-user slave's payload size MUST match master's request
801         // on success, uses zero length of payload to indicate an error
802         // to vhost-user master.
803         match res {
804             Ok(ref buf) if buf.len() == msg.size as usize => {
805                 let reply = VhostUserConfig::new(msg.offset, buf.len() as u32, flags);
806                 self.slave_req_helper
807                     .send_reply_with_payload(hdr, &reply, buf.as_slice())?;
808             }
809             Ok(_) => {
810                 let reply = VhostUserConfig::new(msg.offset, 0, flags);
811                 self.slave_req_helper.send_reply_message(hdr, &reply)?;
812             }
813             Err(_) => {
814                 let reply = VhostUserConfig::new(msg.offset, 0, flags);
815                 self.slave_req_helper.send_reply_message(hdr, &reply)?;
816             }
817         }
818         Ok(())
819     }
820 
set_config(&mut self, size: usize, buf: &[u8]) -> Result<()>821     fn set_config(&mut self, size: usize, buf: &[u8]) -> Result<()> {
822         if size > MAX_MSG_SIZE || size < mem::size_of::<VhostUserConfig>() {
823             return Err(Error::InvalidMessage);
824         }
825         let msg = unsafe { std::ptr::read_unaligned(buf.as_ptr() as *const VhostUserConfig) };
826         if !msg.is_valid() {
827             return Err(Error::InvalidMessage);
828         }
829         if size - mem::size_of::<VhostUserConfig>() != msg.size as usize {
830             return Err(Error::InvalidMessage);
831         }
832         let flags: VhostUserConfigFlags = match VhostUserConfigFlags::from_bits(msg.flags) {
833             Some(val) => val,
834             None => return Err(Error::InvalidMessage),
835         };
836 
837         self.backend.set_config(msg.offset, buf, flags)
838     }
839 
set_slave_req_fd(&mut self, files: Option<Vec<File>>) -> Result<()>840     fn set_slave_req_fd(&mut self, files: Option<Vec<File>>) -> Result<()> {
841         if cfg!(windows) {
842             unimplemented!();
843         } else {
844             let file = take_single_file(files).ok_or(Error::InvalidMessage)?;
845             self.backend.set_slave_req_fd(file);
846             Ok(())
847         }
848     }
849 
handle_vring_fd_request( &mut self, buf: &[u8], files: Option<Vec<File>>, ) -> Result<(u8, Option<File>)>850     fn handle_vring_fd_request(
851         &mut self,
852         buf: &[u8],
853         files: Option<Vec<File>>,
854     ) -> Result<(u8, Option<File>)> {
855         self.slave_req_helper.handle_vring_fd_request(buf, files)
856     }
857 
check_state(&self) -> Result<()>858     fn check_state(&self) -> Result<()> {
859         match self.error {
860             Some(e) => Err(Error::SocketBroken(std::io::Error::from_raw_os_error(e))),
861             None => Ok(()),
862         }
863     }
864 
check_request_size( &self, hdr: &VhostUserMsgHeader<MasterReq>, size: usize, expected: usize, ) -> Result<()>865     fn check_request_size(
866         &self,
867         hdr: &VhostUserMsgHeader<MasterReq>,
868         size: usize,
869         expected: usize,
870     ) -> Result<()> {
871         if hdr.get_size() as usize != expected
872             || hdr.is_reply()
873             || hdr.get_version() != 0x1
874             || size != expected
875         {
876             return Err(Error::InvalidMessage);
877         }
878         Ok(())
879     }
880 
check_attached_files( &self, hdr: &VhostUserMsgHeader<MasterReq>, files: &Option<Vec<File>>, ) -> Result<()>881     fn check_attached_files(
882         &self,
883         hdr: &VhostUserMsgHeader<MasterReq>,
884         files: &Option<Vec<File>>,
885     ) -> Result<()> {
886         match hdr.get_code() {
887             MasterReq::SET_MEM_TABLE
888             | MasterReq::SET_VRING_CALL
889             | MasterReq::SET_VRING_KICK
890             | MasterReq::SET_VRING_ERR
891             | MasterReq::SET_LOG_BASE
892             | MasterReq::SET_LOG_FD
893             | MasterReq::SET_SLAVE_REQ_FD
894             | MasterReq::SET_INFLIGHT_FD
895             | MasterReq::ADD_MEM_REG => Ok(()),
896             _ if files.is_some() => Err(Error::InvalidMessage),
897             _ => Ok(()),
898         }
899     }
900 
extract_request_body<T: Sized + DataInit + VhostUserMsgValidator>( &self, hdr: &VhostUserMsgHeader<MasterReq>, size: usize, buf: &[u8], ) -> Result<T>901     fn extract_request_body<T: Sized + DataInit + VhostUserMsgValidator>(
902         &self,
903         hdr: &VhostUserMsgHeader<MasterReq>,
904         size: usize,
905         buf: &[u8],
906     ) -> Result<T> {
907         self.check_request_size(hdr, size, mem::size_of::<T>())?;
908         let msg = unsafe { std::ptr::read_unaligned(buf.as_ptr() as *const T) };
909         if !msg.is_valid() {
910             return Err(Error::InvalidMessage);
911         }
912         Ok(msg)
913     }
914 
update_reply_ack_flag(&mut self)915     fn update_reply_ack_flag(&mut self) {
916         let vflag = VhostUserVirtioFeatures::PROTOCOL_FEATURES.bits();
917         let pflag = VhostUserProtocolFeatures::REPLY_ACK;
918         if (self.virtio_features & vflag) != 0
919             && self.protocol_features.contains(pflag)
920             && (self.acked_protocol_features & pflag.bits()) != 0
921         {
922             self.slave_req_helper.reply_ack_enabled = true;
923         } else {
924             self.slave_req_helper.reply_ack_enabled = false;
925         }
926     }
927 }
928 
929 impl<S: VhostUserSlaveReqHandler, E: AsRawDescriptor + Endpoint<MasterReq>> AsRawDescriptor
930     for SlaveReqHandler<S, E>
931 {
as_raw_descriptor(&self) -> RawDescriptor932     fn as_raw_descriptor(&self) -> RawDescriptor {
933         // TODO(b/221882601): figure out if this used for polling.
934         self.slave_req_helper.endpoint.as_raw_descriptor()
935     }
936 }
937 
#[cfg(test)]
mod tests {
    use base::INVALID_DESCRIPTOR;

    use super::*;
    use crate::{dummy_slave::DummySlaveReqHandler, MasterReqEndpoint, SystemStream};

    /// A freshly built handler reports a healthy state, reports a failure once an error has
    /// been recorded with `set_failed`, and exposes a valid raw descriptor.
    #[test]
    fn test_slave_req_handler_new() {
        let (sock, _peer) = SystemStream::pair().unwrap();
        let backend = Arc::new(Mutex::new(DummySlaveReqHandler::new()));
        let mut handler = SlaveReqHandler::new(MasterReqEndpoint::from(sock), backend);

        assert!(handler.check_state().is_ok());
        handler.set_failed(libc::EAGAIN);
        assert!(handler.check_state().is_err());
        assert_ne!(handler.as_raw_descriptor(), INVALID_DESCRIPTOR);
    }
}
958