• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 // Copyright 2021 The ChromiumOS Authors
2 // Use of this source code is governed by a BSD-style license that can be
3 // found in the LICENSE file.
4 
5 //! Library for implementing vhost-user device executables.
6 //!
7 //! This crate provides
8 //! * `VhostUserBackend` trait, which is a collection of methods to handle vhost-user requests, and
9 //! * `DeviceRequestHandler` struct, which makes a connection to a VMM and starts an event loop.
10 //!
11 //! They are expected to be used as follows:
12 //!
13 //! 1. Define a struct and implement `VhostUserBackend` for it.
14 //! 2. Create a `DeviceRequestHandler` with the backend struct.
15 //! 3. Drive the `DeviceRequestHandler::run` async fn with an executor.
16 //!
17 //! ```ignore
18 //! struct MyBackend {
19 //!   /* fields */
20 //! }
21 //!
22 //! impl VhostUserBackend for MyBackend {
23 //!   /* implement methods */
24 //! }
25 //!
26 //! fn main() -> Result<(), Box<dyn Error>> {
27 //!   let backend = MyBackend { /* initialize fields */ };
28 //!   let handler = DeviceRequestHandler::new(backend);
29 //!   let socket = std::path::Path::new("/path/to/socket");
30 //!   let ex = cros_async::Executor::new()?;
31 //!
32 //!   if let Err(e) = ex.run_until(handler.run(socket, &ex)) {
33 //!     eprintln!("error happened: {}", e);
34 //!   }
35 //!   Ok(())
36 //! }
37 //! ```
38 //!
39 // Implementation note:
40 // This code lets us take advantage of the vmm_vhost low level implementation of the vhost user
41 // protocol. DeviceRequestHandler implements the VhostUserSlaveReqHandlerMut trait from vmm_vhost,
42 // and includes some common code for setting up guest memory and managing partially configured
43 // vrings. DeviceRequestHandler::run watches the vhost-user socket and then calls handle_request()
44 // when it becomes readable. handle_request() reads and parses the message and then calls one of the
45 // VhostUserSlaveReqHandlerMut trait methods. These dispatch back to the supplied VhostUserBackend
46 // implementation (this is what our devices implement).
47 
48 pub(super) mod sys;
49 
50 use std::collections::BTreeMap;
51 use std::convert::From;
52 use std::convert::TryFrom;
53 use std::fs::File;
54 use std::num::Wrapping;
55 #[cfg(unix)]
56 use std::os::unix::io::AsRawFd;
57 use std::sync::Arc;
58 
59 use anyhow::bail;
60 use anyhow::Context;
61 #[cfg(unix)]
62 use base::clear_fd_flags;
63 use base::error;
64 use base::Event;
65 use base::FromRawDescriptor;
66 use base::IntoRawDescriptor;
67 use base::Protection;
68 use base::SafeDescriptor;
69 use base::SharedMemory;
70 use sys::Doorbell;
71 use vm_control::VmMemorySource;
72 use vm_memory::GuestAddress;
73 use vm_memory::GuestMemory;
74 use vm_memory::MemoryRegion;
75 use vmm_vhost::connection::Endpoint;
76 use vmm_vhost::message::SlaveReq;
77 use vmm_vhost::message::VhostSharedMemoryRegion;
78 use vmm_vhost::message::VhostUserConfigFlags;
79 use vmm_vhost::message::VhostUserGpuMapMsg;
80 use vmm_vhost::message::VhostUserInflight;
81 use vmm_vhost::message::VhostUserMemoryRegion;
82 use vmm_vhost::message::VhostUserProtocolFeatures;
83 use vmm_vhost::message::VhostUserShmemMapMsg;
84 use vmm_vhost::message::VhostUserShmemMapMsgFlags;
85 use vmm_vhost::message::VhostUserShmemUnmapMsg;
86 use vmm_vhost::message::VhostUserSingleMemoryRegion;
87 use vmm_vhost::message::VhostUserVirtioFeatures;
88 use vmm_vhost::message::VhostUserVringAddrFlags;
89 use vmm_vhost::message::VhostUserVringState;
90 use vmm_vhost::Error as VhostError;
91 use vmm_vhost::Protocol;
92 use vmm_vhost::Result as VhostResult;
93 use vmm_vhost::Slave;
94 use vmm_vhost::VhostUserMasterReqHandler;
95 use vmm_vhost::VhostUserSlaveReqHandlerMut;
96 
97 use crate::virtio::Queue;
98 use crate::virtio::SharedMemoryMapper;
99 use crate::virtio::SharedMemoryRegion;
100 use crate::virtio::SignalableInterrupt;
101 
/// Upper bound on the number of entries a single virtqueue may hold (32768).
const MAX_VRING_LEN: u16 = 1 << 15;
104 
105 /// An event to deliver an interrupt to the guest.
106 ///
107 /// Unlike `devices::Interrupt`, this doesn't support interrupt status and signal resampling.
108 // TODO(b/187487351): To avoid sending unnecessary events, we might want to support interrupt
109 // status. For this purpose, we need a mechanism to share interrupt status between the vmm and the
110 // device process.
111 #[derive(Clone)]
112 pub struct CallEvent(Arc<Event>);
113 
114 impl CallEvent {
115     #[cfg_attr(windows, allow(dead_code))]
into_inner(self) -> Event116     pub fn into_inner(self) -> Event {
117         Arc::try_unwrap(self.0).unwrap()
118     }
119 }
120 
121 impl SignalableInterrupt for CallEvent {
signal(&self, _vector: u16, _interrupt_status_mask: u32)122     fn signal(&self, _vector: u16, _interrupt_status_mask: u32) {
123         self.0.signal().unwrap();
124     }
125 
signal_config_changed(&self)126     fn signal_config_changed(&self) {} // TODO(dgreid)
127 
get_resample_evt(&self) -> Option<&Event>128     fn get_resample_evt(&self) -> Option<&Event> {
129         None
130     }
131 
do_interrupt_resample(&self)132     fn do_interrupt_resample(&self) {}
133 }
134 
135 impl From<File> for CallEvent {
from(file: File) -> Self136     fn from(file: File) -> Self {
137         // Safe because we own the file.
138         CallEvent(Arc::new(unsafe {
139             Event::from_raw_descriptor(file.into_raw_descriptor())
140         }))
141     }
142 }
143 
/// Describes one mapping from the vmm's virtual address space to guest physical memory.
///
/// Used to translate addresses found in messages from the vmm into guest addresses.
#[derive(Default)]
pub struct MappingInfo {
    /// Start of the region in the vmm's virtual address space.
    pub vmm_addr: u64,
    /// Guest physical address at which the region starts.
    pub guest_phys: u64,
    /// Length of the region.
    pub size: u64,
}
152 
vmm_va_to_gpa(maps: &[MappingInfo], vmm_va: u64) -> VhostResult<GuestAddress>153 pub fn vmm_va_to_gpa(maps: &[MappingInfo], vmm_va: u64) -> VhostResult<GuestAddress> {
154     for map in maps {
155         if vmm_va >= map.vmm_addr && vmm_va < map.vmm_addr + map.size {
156             return Ok(GuestAddress(vmm_va - map.vmm_addr + map.guest_phys));
157         }
158     }
159     Err(VhostError::InvalidMessage)
160 }
161 
162 /// Trait for vhost-user backend.
163 pub trait VhostUserBackend {
164     /// The maximum number of queues that this backend can manage.
max_queue_num(&self) -> usize165     fn max_queue_num(&self) -> usize;
166 
167     /// The set of feature bits that this backend supports.
features(&self) -> u64168     fn features(&self) -> u64;
169 
170     /// Acknowledges that this set of features should be enabled.
ack_features(&mut self, value: u64) -> anyhow::Result<()>171     fn ack_features(&mut self, value: u64) -> anyhow::Result<()>;
172 
173     /// Returns the set of enabled features.
acked_features(&self) -> u64174     fn acked_features(&self) -> u64;
175 
176     /// The set of protocol feature bits that this backend supports.
protocol_features(&self) -> VhostUserProtocolFeatures177     fn protocol_features(&self) -> VhostUserProtocolFeatures;
178 
179     /// Acknowledges that this set of protocol features should be enabled.
ack_protocol_features(&mut self, _value: u64) -> anyhow::Result<()>180     fn ack_protocol_features(&mut self, _value: u64) -> anyhow::Result<()>;
181 
182     /// Returns the set of enabled protocol features.
acked_protocol_features(&self) -> u64183     fn acked_protocol_features(&self) -> u64;
184 
185     /// Reads this device configuration space at `offset`.
read_config(&self, offset: u64, dst: &mut [u8])186     fn read_config(&self, offset: u64, dst: &mut [u8]);
187 
188     /// writes `data` to this device's configuration space at `offset`.
write_config(&self, _offset: u64, _data: &[u8])189     fn write_config(&self, _offset: u64, _data: &[u8]) {}
190 
191     /// Indicates that the backend should start processing requests for virtio queue number `idx`.
192     /// This method must not block the current thread so device backends should either spawn an
193     /// async task or another thread to handle messages from the Queue.
start_queue( &mut self, idx: usize, queue: Queue, mem: GuestMemory, doorbell: Doorbell, kick_evt: Event, ) -> anyhow::Result<()>194     fn start_queue(
195         &mut self,
196         idx: usize,
197         queue: Queue,
198         mem: GuestMemory,
199         doorbell: Doorbell,
200         kick_evt: Event,
201     ) -> anyhow::Result<()>;
202 
203     /// Indicates that the backend should stop processing requests for virtio queue number `idx`.
stop_queue(&mut self, idx: usize)204     fn stop_queue(&mut self, idx: usize);
205 
206     /// Resets the vhost-user backend.
reset(&mut self)207     fn reset(&mut self);
208 
209     /// Returns the device's shared memory region if present.
get_shared_memory_region(&self) -> Option<SharedMemoryRegion>210     fn get_shared_memory_region(&self) -> Option<SharedMemoryRegion> {
211         None
212     }
213 
214     /// Accepts `VhostBackendReqConnection` to conduct Vhost backend to frontend message
215     /// handling.
216     ///
217     /// This method will be called when `VhostUserProtocolFeatures::SLAVE_REQ` is
218     /// negotiated.
set_backend_req_connection(&mut self, _conn: VhostBackendReqConnection)219     fn set_backend_req_connection(&mut self, _conn: VhostBackendReqConnection) {
220         error!("set_backend_req_connection is not implemented");
221     }
222 }
223 
224 /// A virtio ring entry.
225 struct Vring {
226     queue: Queue,
227     doorbell: Option<Doorbell>,
228     enabled: bool,
229 }
230 
231 impl Vring {
new(max_size: u16) -> Self232     fn new(max_size: u16) -> Self {
233         Self {
234             queue: Queue::new(max_size),
235             doorbell: None,
236             enabled: false,
237         }
238     }
239 
reset(&mut self)240     fn reset(&mut self) {
241         self.queue.reset();
242         self.doorbell = None;
243         self.enabled = false;
244     }
245 }
246 
247 /// Trait for defining vhost-user ops that are platform-dependent.
248 pub trait VhostUserPlatformOps {
249     /// Returns the protocol implemented by these platform ops.
protocol(&self) -> Protocol250     fn protocol(&self) -> Protocol;
251     /// Create the guest memory for the backend.
252     ///
253     /// `contexts` and `files` must be the same size, and provide a description of the memory
254     /// regions to map as well as the file descriptors from which to obtain the memory backing these
255     /// regions, respectively.
256     ///
257     /// The returned tuple contains the constructed `GuestMemory` from these memory contexts, as
258     /// well as a vector describing all the mappings described by these contexts.
259 
set_mem_table( &mut self, contexts: &[VhostUserMemoryRegion], files: Vec<File>, ) -> VhostResult<(GuestMemory, Vec<MappingInfo>)>260     fn set_mem_table(
261         &mut self,
262         contexts: &[VhostUserMemoryRegion],
263         files: Vec<File>,
264     ) -> VhostResult<(GuestMemory, Vec<MappingInfo>)>;
265 
266     /// Return an `Event` that will be signaled by the frontend whenever vring `index` should be
267     /// processed.
268     ///
269     /// For protocols that support providing that event using a file descriptor (`Regular`), it is
270     /// provided by `file`. For other protocols, `file` will be `None`.
set_vring_kick(&mut self, index: u8, file: Option<File>) -> VhostResult<Event>271     fn set_vring_kick(&mut self, index: u8, file: Option<File>) -> VhostResult<Event>;
272 
273     /// Return a `Doorbell` that the backend will signal whenever it puts used buffers for vring
274     /// `index`.
275     ///
276     /// For protocols that support listening to a file descriptor (`Regular`), `file` provides a
277     /// file descriptor from which the `Doorbell` should be built. For other protocols, it will be
278     /// `None`.
set_vring_call(&mut self, index: u8, file: Option<File>) -> VhostResult<Doorbell>279     fn set_vring_call(&mut self, index: u8, file: Option<File>) -> VhostResult<Doorbell>;
280 }
281 
282 /// Ops for running vhost-user over a stream (i.e. regular protocol).
283 pub(super) struct VhostUserRegularOps;
284 
285 impl VhostUserPlatformOps for VhostUserRegularOps {
protocol(&self) -> Protocol286     fn protocol(&self) -> Protocol {
287         Protocol::Regular
288     }
289 
set_mem_table( &mut self, contexts: &[VhostUserMemoryRegion], files: Vec<File>, ) -> VhostResult<(GuestMemory, Vec<MappingInfo>)>290     fn set_mem_table(
291         &mut self,
292         contexts: &[VhostUserMemoryRegion],
293         files: Vec<File>,
294     ) -> VhostResult<(GuestMemory, Vec<MappingInfo>)> {
295         if files.len() != contexts.len() {
296             return Err(VhostError::InvalidParam);
297         }
298 
299         let mut regions = Vec::with_capacity(files.len());
300         for (region, file) in contexts.iter().zip(files.into_iter()) {
301             let region = MemoryRegion::new_from_shm(
302                 region.memory_size,
303                 GuestAddress(region.guest_phys_addr),
304                 region.mmap_offset,
305                 Arc::new(
306                     SharedMemory::from_safe_descriptor(
307                         SafeDescriptor::from(file),
308                         Some(region.memory_size),
309                     )
310                     .unwrap(),
311                 ),
312             )
313             .map_err(|e| {
314                 error!("failed to create a memory region: {}", e);
315                 VhostError::InvalidOperation
316             })?;
317             regions.push(region);
318         }
319         let guest_mem = GuestMemory::from_regions(regions).map_err(|e| {
320             error!("failed to create guest memory: {}", e);
321             VhostError::InvalidOperation
322         })?;
323 
324         let vmm_maps = contexts
325             .iter()
326             .map(|region| MappingInfo {
327                 vmm_addr: region.user_addr,
328                 guest_phys: region.guest_phys_addr,
329                 size: region.memory_size,
330             })
331             .collect();
332         Ok((guest_mem, vmm_maps))
333     }
334 
set_vring_kick(&mut self, _index: u8, file: Option<File>) -> VhostResult<Event>335     fn set_vring_kick(&mut self, _index: u8, file: Option<File>) -> VhostResult<Event> {
336         let file = file.ok_or(VhostError::InvalidParam)?;
337         // Remove O_NONBLOCK from kick_fd. Otherwise, uring_executor will fails when we read
338         // values via `next_val()` later.
339         // This is only required (and can only be done) on Unix platforms.
340         #[cfg(unix)]
341         if let Err(e) = clear_fd_flags(file.as_raw_fd(), libc::O_NONBLOCK) {
342             error!("failed to remove O_NONBLOCK for kick fd: {}", e);
343             return Err(VhostError::InvalidParam);
344         }
345 
346         // Safe because we own the file.
347         Ok(unsafe { Event::from_raw_descriptor(file.into_raw_descriptor()) })
348     }
349 
set_vring_call(&mut self, _index: u8, file: Option<File>) -> VhostResult<Doorbell>350     fn set_vring_call(&mut self, _index: u8, file: Option<File>) -> VhostResult<Doorbell> {
351         let file = file.ok_or(VhostError::InvalidParam)?;
352         Ok(
353             // `Doorbell` is defined as `CallEvent` on Windows, prevent clippy from giving us a
354             // warning about the unneeded conversion.
355             #[allow(clippy::useless_conversion)]
356             Doorbell::from(CallEvent::try_from(file).map_err(|_| {
357                 error!("failed to convert callfd to CallSignal");
358                 VhostError::InvalidParam
359             })?),
360         )
361     }
362 }
363 
364 /// A request handler for devices implementing `VhostUserBackend`.
365 pub struct DeviceRequestHandler {
366     vrings: Vec<Vring>,
367     owned: bool,
368     vmm_maps: Option<Vec<MappingInfo>>,
369     mem: Option<GuestMemory>,
370     backend: Box<dyn VhostUserBackend>,
371     ops: Box<dyn VhostUserPlatformOps>,
372 }
373 
374 impl DeviceRequestHandler {
375     /// Creates a vhost-user handler instance for `backend` with a different set of platform ops
376     /// than the regular vhost-user ones.
new( backend: Box<dyn VhostUserBackend>, ops: Box<dyn VhostUserPlatformOps>, ) -> Self377     pub(crate) fn new(
378         backend: Box<dyn VhostUserBackend>,
379         ops: Box<dyn VhostUserPlatformOps>,
380     ) -> Self {
381         let mut vrings = Vec::with_capacity(backend.max_queue_num());
382         for _ in 0..backend.max_queue_num() {
383             vrings.push(Vring::new(MAX_VRING_LEN));
384         }
385 
386         DeviceRequestHandler {
387             vrings,
388             owned: false,
389             vmm_maps: None,
390             mem: None,
391             backend,
392             ops,
393         }
394     }
395 }
396 
397 impl VhostUserSlaveReqHandlerMut for DeviceRequestHandler {
protocol(&self) -> Protocol398     fn protocol(&self) -> Protocol {
399         self.ops.protocol()
400     }
401 
set_owner(&mut self) -> VhostResult<()>402     fn set_owner(&mut self) -> VhostResult<()> {
403         if self.owned {
404             return Err(VhostError::InvalidOperation);
405         }
406         self.owned = true;
407         Ok(())
408     }
409 
reset_owner(&mut self) -> VhostResult<()>410     fn reset_owner(&mut self) -> VhostResult<()> {
411         self.owned = false;
412         self.backend.reset();
413         Ok(())
414     }
415 
get_features(&mut self) -> VhostResult<u64>416     fn get_features(&mut self) -> VhostResult<u64> {
417         let features = self.backend.features();
418         Ok(features)
419     }
420 
set_features(&mut self, features: u64) -> VhostResult<()>421     fn set_features(&mut self, features: u64) -> VhostResult<()> {
422         if !self.owned {
423             return Err(VhostError::InvalidOperation);
424         }
425 
426         if (features & !(self.backend.features())) != 0 {
427             return Err(VhostError::InvalidParam);
428         }
429 
430         if let Err(e) = self.backend.ack_features(features) {
431             error!("failed to acknowledge features 0x{:x}: {}", features, e);
432             return Err(VhostError::InvalidOperation);
433         }
434 
435         // If VHOST_USER_F_PROTOCOL_FEATURES has not been negotiated, the ring is initialized in an
436         // enabled state.
437         // If VHOST_USER_F_PROTOCOL_FEATURES has been negotiated, the ring is initialized in a
438         // disabled state.
439         // Client must not pass data to/from the backend until ring is enabled by
440         // VHOST_USER_SET_VRING_ENABLE with parameter 1, or after it has been disabled by
441         // VHOST_USER_SET_VRING_ENABLE with parameter 0.
442         let acked_features = self.backend.acked_features();
443         let vring_enabled = VhostUserVirtioFeatures::PROTOCOL_FEATURES.bits() & acked_features != 0;
444         for v in &mut self.vrings {
445             v.enabled = vring_enabled;
446         }
447 
448         Ok(())
449     }
450 
get_protocol_features(&mut self) -> VhostResult<VhostUserProtocolFeatures>451     fn get_protocol_features(&mut self) -> VhostResult<VhostUserProtocolFeatures> {
452         Ok(self.backend.protocol_features())
453     }
454 
set_protocol_features(&mut self, features: u64) -> VhostResult<()>455     fn set_protocol_features(&mut self, features: u64) -> VhostResult<()> {
456         if let Err(e) = self.backend.ack_protocol_features(features) {
457             error!("failed to set protocol features 0x{:x}: {}", features, e);
458             return Err(VhostError::InvalidOperation);
459         }
460         Ok(())
461     }
462 
set_mem_table( &mut self, contexts: &[VhostUserMemoryRegion], files: Vec<File>, ) -> VhostResult<()>463     fn set_mem_table(
464         &mut self,
465         contexts: &[VhostUserMemoryRegion],
466         files: Vec<File>,
467     ) -> VhostResult<()> {
468         let (guest_mem, vmm_maps) = self.ops.set_mem_table(contexts, files)?;
469         self.mem = Some(guest_mem);
470         self.vmm_maps = Some(vmm_maps);
471         Ok(())
472     }
473 
get_queue_num(&mut self) -> VhostResult<u64>474     fn get_queue_num(&mut self) -> VhostResult<u64> {
475         Ok(self.vrings.len() as u64)
476     }
477 
set_vring_num(&mut self, index: u32, num: u32) -> VhostResult<()>478     fn set_vring_num(&mut self, index: u32, num: u32) -> VhostResult<()> {
479         if index as usize >= self.vrings.len() || num == 0 || num > MAX_VRING_LEN.into() {
480             return Err(VhostError::InvalidParam);
481         }
482         self.vrings[index as usize].queue.set_size(num as u16);
483 
484         Ok(())
485     }
486 
set_vring_addr( &mut self, index: u32, _flags: VhostUserVringAddrFlags, descriptor: u64, used: u64, available: u64, _log: u64, ) -> VhostResult<()>487     fn set_vring_addr(
488         &mut self,
489         index: u32,
490         _flags: VhostUserVringAddrFlags,
491         descriptor: u64,
492         used: u64,
493         available: u64,
494         _log: u64,
495     ) -> VhostResult<()> {
496         if index as usize >= self.vrings.len() {
497             return Err(VhostError::InvalidParam);
498         }
499 
500         let vmm_maps = self.vmm_maps.as_ref().ok_or(VhostError::InvalidParam)?;
501         let vring = &mut self.vrings[index as usize];
502         vring
503             .queue
504             .set_desc_table(vmm_va_to_gpa(vmm_maps, descriptor)?);
505         vring
506             .queue
507             .set_avail_ring(vmm_va_to_gpa(vmm_maps, available)?);
508         vring.queue.set_used_ring(vmm_va_to_gpa(vmm_maps, used)?);
509 
510         Ok(())
511     }
512 
set_vring_base(&mut self, index: u32, base: u32) -> VhostResult<()>513     fn set_vring_base(&mut self, index: u32, base: u32) -> VhostResult<()> {
514         if index as usize >= self.vrings.len() || base >= MAX_VRING_LEN.into() {
515             return Err(VhostError::InvalidParam);
516         }
517 
518         let vring = &mut self.vrings[index as usize];
519         vring.queue.next_avail = Wrapping(base as u16);
520         vring.queue.next_used = Wrapping(base as u16);
521 
522         Ok(())
523     }
524 
get_vring_base(&mut self, index: u32) -> VhostResult<VhostUserVringState>525     fn get_vring_base(&mut self, index: u32) -> VhostResult<VhostUserVringState> {
526         if index as usize >= self.vrings.len() {
527             return Err(VhostError::InvalidParam);
528         }
529 
530         // Quotation from vhost-user spec:
531         // Client must start ring upon receiving a kick (that is, detecting
532         // that file descriptor is readable) on the descriptor specified by
533         // VHOST_USER_SET_VRING_KICK, and stop ring upon receiving
534         // VHOST_USER_GET_VRING_BASE.
535         self.backend.stop_queue(index as usize);
536 
537         let vring = &mut self.vrings[index as usize];
538         vring.reset();
539 
540         Ok(VhostUserVringState::new(
541             index,
542             vring.queue.next_avail.0 as u32,
543         ))
544     }
545 
set_vring_kick(&mut self, index: u8, file: Option<File>) -> VhostResult<()>546     fn set_vring_kick(&mut self, index: u8, file: Option<File>) -> VhostResult<()> {
547         if index as usize >= self.vrings.len() {
548             return Err(VhostError::InvalidParam);
549         }
550 
551         let vring = &mut self.vrings[index as usize];
552         if vring.queue.ready() {
553             error!("kick fd cannot replaced after queue is started");
554             return Err(VhostError::InvalidOperation);
555         }
556 
557         let kick_evt = self.ops.set_vring_kick(index, file)?;
558 
559         // Enable any virtqueue features that were negotiated (like VIRTIO_RING_F_EVENT_IDX).
560         vring.queue.ack_features(self.backend.acked_features());
561         vring.queue.set_ready(true);
562 
563         let queue = match vring.queue.activate() {
564             Ok(queue) => queue,
565             Err(e) => {
566                 error!("failed to activate vring: {:#}", e);
567                 return Err(VhostError::SlaveInternalError);
568             }
569         };
570 
571         let doorbell = vring.doorbell.clone().ok_or(VhostError::InvalidOperation)?;
572         let mem = self
573             .mem
574             .as_ref()
575             .cloned()
576             .ok_or(VhostError::InvalidOperation)?;
577 
578         if let Err(e) = self
579             .backend
580             .start_queue(index as usize, queue, mem, doorbell, kick_evt)
581         {
582             error!("Failed to start queue {}: {}", index, e);
583             return Err(VhostError::SlaveInternalError);
584         }
585 
586         Ok(())
587     }
588 
set_vring_call(&mut self, index: u8, file: Option<File>) -> VhostResult<()>589     fn set_vring_call(&mut self, index: u8, file: Option<File>) -> VhostResult<()> {
590         if index as usize >= self.vrings.len() {
591             return Err(VhostError::InvalidParam);
592         }
593 
594         let doorbell = self.ops.set_vring_call(index, file)?;
595         self.vrings[index as usize].doorbell = Some(doorbell);
596         Ok(())
597     }
598 
set_vring_err(&mut self, _index: u8, _fd: Option<File>) -> VhostResult<()>599     fn set_vring_err(&mut self, _index: u8, _fd: Option<File>) -> VhostResult<()> {
600         // TODO
601         Ok(())
602     }
603 
set_vring_enable(&mut self, index: u32, enable: bool) -> VhostResult<()>604     fn set_vring_enable(&mut self, index: u32, enable: bool) -> VhostResult<()> {
605         if index as usize >= self.vrings.len() {
606             return Err(VhostError::InvalidParam);
607         }
608 
609         // This request should be handled only when VHOST_USER_F_PROTOCOL_FEATURES
610         // has been negotiated.
611         if self.backend.acked_features() & VhostUserVirtioFeatures::PROTOCOL_FEATURES.bits() == 0 {
612             return Err(VhostError::InvalidOperation);
613         }
614 
615         // Slave must not pass data to/from the backend until ring is
616         // enabled by VHOST_USER_SET_VRING_ENABLE with parameter 1,
617         // or after it has been disabled by VHOST_USER_SET_VRING_ENABLE
618         // with parameter 0.
619         self.vrings[index as usize].enabled = enable;
620 
621         Ok(())
622     }
623 
get_config( &mut self, offset: u32, size: u32, _flags: VhostUserConfigFlags, ) -> VhostResult<Vec<u8>>624     fn get_config(
625         &mut self,
626         offset: u32,
627         size: u32,
628         _flags: VhostUserConfigFlags,
629     ) -> VhostResult<Vec<u8>> {
630         let mut data = vec![0; size as usize];
631         self.backend.read_config(u64::from(offset), &mut data);
632         Ok(data)
633     }
634 
set_config( &mut self, offset: u32, buf: &[u8], _flags: VhostUserConfigFlags, ) -> VhostResult<()>635     fn set_config(
636         &mut self,
637         offset: u32,
638         buf: &[u8],
639         _flags: VhostUserConfigFlags,
640     ) -> VhostResult<()> {
641         self.backend.write_config(u64::from(offset), buf);
642         Ok(())
643     }
644 
set_slave_req_fd(&mut self, ep: Box<dyn Endpoint<SlaveReq>>)645     fn set_slave_req_fd(&mut self, ep: Box<dyn Endpoint<SlaveReq>>) {
646         let conn = VhostBackendReqConnection::new(
647             Slave::new(ep),
648             self.backend.get_shared_memory_region().map(|r| r.id),
649         );
650         self.backend.set_backend_req_connection(conn);
651     }
652 
get_inflight_fd( &mut self, _inflight: &VhostUserInflight, ) -> VhostResult<(VhostUserInflight, File)>653     fn get_inflight_fd(
654         &mut self,
655         _inflight: &VhostUserInflight,
656     ) -> VhostResult<(VhostUserInflight, File)> {
657         unimplemented!("get_inflight_fd");
658     }
659 
set_inflight_fd(&mut self, _inflight: &VhostUserInflight, _file: File) -> VhostResult<()>660     fn set_inflight_fd(&mut self, _inflight: &VhostUserInflight, _file: File) -> VhostResult<()> {
661         unimplemented!("set_inflight_fd");
662     }
663 
get_max_mem_slots(&mut self) -> VhostResult<u64>664     fn get_max_mem_slots(&mut self) -> VhostResult<u64> {
665         //TODO
666         Ok(0)
667     }
668 
add_mem_region( &mut self, _region: &VhostUserSingleMemoryRegion, _fd: File, ) -> VhostResult<()>669     fn add_mem_region(
670         &mut self,
671         _region: &VhostUserSingleMemoryRegion,
672         _fd: File,
673     ) -> VhostResult<()> {
674         //TODO
675         Ok(())
676     }
677 
remove_mem_region(&mut self, _region: &VhostUserSingleMemoryRegion) -> VhostResult<()>678     fn remove_mem_region(&mut self, _region: &VhostUserSingleMemoryRegion) -> VhostResult<()> {
679         //TODO
680         Ok(())
681     }
682 
get_shared_memory_regions(&mut self) -> VhostResult<Vec<VhostSharedMemoryRegion>>683     fn get_shared_memory_regions(&mut self) -> VhostResult<Vec<VhostSharedMemoryRegion>> {
684         Ok(if let Some(r) = self.backend.get_shared_memory_region() {
685             vec![VhostSharedMemoryRegion::new(r.id, r.length)]
686         } else {
687             Vec::new()
688         })
689     }
690 }
691 
692 /// Indicates the state of backend request connection
693 pub enum VhostBackendReqConnectionState {
694     /// A backend request connection (`VhostBackendReqConnection`) is established
695     Connected(VhostBackendReqConnection),
696     /// No backend request connection has been established yet
697     NoConnection,
698 }
699 
700 /// Keeps track of Vhost user backend request connection.
701 pub struct VhostBackendReqConnection {
702     conn: Slave,
703     shmem_info: Option<ShmemInfo>,
704 }
705 
/// Bookkeeping for a device's shared memory region.
#[derive(Clone)]
struct ShmemInfo {
    /// Id of the shared memory region.
    shmid: u8,
    /// Mapped ranges, keyed by offset with the mapping size as the value.
    mapped_regions: BTreeMap<u64 /* offset */, u64 /* size */>,
}
711 
712 impl VhostBackendReqConnection {
new(conn: Slave, shmid: Option<u8>) -> Self713     pub fn new(conn: Slave, shmid: Option<u8>) -> Self {
714         let shmem_info = shmid.map(|shmid| ShmemInfo {
715             shmid,
716             mapped_regions: BTreeMap::new(),
717         });
718         Self { conn, shmem_info }
719     }
720 
721     /// Send `VHOST_USER_CONFIG_CHANGE_MSG` to the frontend
send_config_changed(&self) -> anyhow::Result<()>722     pub fn send_config_changed(&self) -> anyhow::Result<()> {
723         self.conn
724             .handle_config_change()
725             .context("Could not send config change message")?;
726         Ok(())
727     }
728 
729     /// Create a SharedMemoryMapper trait object from the ShmemInfo.
take_shmem_mapper(&mut self) -> anyhow::Result<Box<dyn SharedMemoryMapper>>730     pub fn take_shmem_mapper(&mut self) -> anyhow::Result<Box<dyn SharedMemoryMapper>> {
731         let shmem_info = self
732             .shmem_info
733             .take()
734             .context("could not take shared memory mapper information")?;
735 
736         Ok(Box::new(VhostShmemMapper {
737             conn: self.conn.clone(),
738             shmem_info,
739         }))
740     }
741 }
742 
743 struct VhostShmemMapper {
744     conn: Slave,
745     shmem_info: ShmemInfo,
746 }
747 
748 impl SharedMemoryMapper for VhostShmemMapper {
add_mapping( &mut self, source: VmMemorySource, offset: u64, prot: Protection, ) -> anyhow::Result<()>749     fn add_mapping(
750         &mut self,
751         source: VmMemorySource,
752         offset: u64,
753         prot: Protection,
754     ) -> anyhow::Result<()> {
755         // True if we should send gpu_map instead of shmem_map.
756         let is_gpu = matches!(&source, &VmMemorySource::Vulkan { .. });
757 
758         let size = if is_gpu {
759             match source {
760                 VmMemorySource::Vulkan {
761                     descriptor,
762                     handle_type,
763                     memory_idx,
764                     device_id,
765                     size,
766                 } => {
767                     let msg = VhostUserGpuMapMsg::new(
768                         self.shmem_info.shmid,
769                         offset,
770                         size,
771                         memory_idx,
772                         handle_type,
773                         device_id.device_uuid,
774                         device_id.driver_uuid,
775                     );
776                     self.conn
777                         .gpu_map(&msg, &descriptor)
778                         .context("failed to map memory")?;
779                     size
780                 }
781                 _ => unreachable!("inconsistent pattern match"),
782             }
783         } else {
784             let (descriptor, fd_offset, size) = match source {
785                 VmMemorySource::Descriptor {
786                     descriptor,
787                     offset,
788                     size,
789                 } => (descriptor, offset, size),
790                 VmMemorySource::SharedMemory(shmem) => {
791                     let size = shmem.size();
792                     // Safe because we own shmem.
793                     let descriptor =
794                         unsafe { SafeDescriptor::from_raw_descriptor(shmem.into_raw_descriptor()) };
795                     (descriptor, 0, size)
796                 }
797                 _ => bail!("unsupported source"),
798             };
799             let flags = VhostUserShmemMapMsgFlags::from(prot);
800             let msg =
801                 VhostUserShmemMapMsg::new(self.shmem_info.shmid, offset, fd_offset, size, flags);
802             self.conn
803                 .shmem_map(&msg, &descriptor)
804                 .context("failed to map memory")?;
805             size
806         };
807 
808         self.shmem_info.mapped_regions.insert(offset, size);
809         Ok(())
810     }
811 
remove_mapping(&mut self, offset: u64) -> anyhow::Result<()>812     fn remove_mapping(&mut self, offset: u64) -> anyhow::Result<()> {
813         let size = self
814             .shmem_info
815             .mapped_regions
816             .remove(&offset)
817             .context("unknown offset")?;
818         let msg = VhostUserShmemUnmapMsg::new(self.shmem_info.shmid, offset, size);
819         self.conn
820             .shmem_unmap(&msg)
821             .context("failed to map memory")
822             .map(|_| ())
823     }
824 }
825 
826 #[cfg(test)]
827 mod tests {
828     #[cfg(unix)]
829     use std::sync::mpsc::channel;
830     #[cfg(unix)]
831     use std::sync::Barrier;
832 
833     use anyhow::anyhow;
834     use anyhow::bail;
835     #[cfg(unix)]
836     use tempfile::Builder;
837     #[cfg(unix)]
838     use tempfile::TempDir;
839     use vmm_vhost::message::MasterReq;
840     use vmm_vhost::SlaveReqHandler;
841     use vmm_vhost::VhostUserSlaveReqHandler;
842     use zerocopy::AsBytes;
843     use zerocopy::FromBytes;
844 
845     use super::*;
846     use crate::virtio::vhost::user::vmm::VhostUserHandler;
847 
    /// Fixed-layout device config served by `FakeBackend::read_config`.
    ///
    /// `packed(4)` caps field alignment at 4 bytes, so no padding is inserted between `x`
    /// and `y`.
    #[derive(Clone, Copy, Debug, PartialEq, Eq, AsBytes, FromBytes)]
    #[repr(C, packed(4))]
    struct FakeConfig {
        x: u32,
        y: u64,
    }

    /// Config contents that `FakeBackend` serves and that the VMM side of the tests expects to
    /// read back.
    const FAKE_CONFIG_DATA: FakeConfig = FakeConfig { x: 1, y: 2 };
856 
    /// Minimal `VhostUserBackend` for the tests; it only keeps feature negotiation state.
    pub(super) struct FakeBackend {
        /// Virtio feature bits returned by `features()`.
        avail_features: u64,
        /// Virtio feature bits accumulated by `ack_features()`.
        acked_features: u64,
        /// Protocol features stored by `ack_protocol_features()`.
        acked_protocol_features: VhostUserProtocolFeatures,
    }
862 
863     impl FakeBackend {
864         const MAX_QUEUE_NUM: usize = 16;
865 
new() -> Self866         pub(super) fn new() -> Self {
867             Self {
868                 avail_features: VhostUserVirtioFeatures::PROTOCOL_FEATURES.bits(),
869                 acked_features: 0,
870                 acked_protocol_features: VhostUserProtocolFeatures::empty(),
871             }
872         }
873     }
874 
875     impl VhostUserBackend for FakeBackend {
max_queue_num(&self) -> usize876         fn max_queue_num(&self) -> usize {
877             Self::MAX_QUEUE_NUM
878         }
879 
features(&self) -> u64880         fn features(&self) -> u64 {
881             self.avail_features
882         }
883 
ack_features(&mut self, value: u64) -> anyhow::Result<()>884         fn ack_features(&mut self, value: u64) -> anyhow::Result<()> {
885             let unrequested_features = value & !self.avail_features;
886             if unrequested_features != 0 {
887                 bail!(
888                     "invalid protocol features are given: 0x{:x}",
889                     unrequested_features
890                 );
891             }
892             self.acked_features |= value;
893             Ok(())
894         }
895 
acked_features(&self) -> u64896         fn acked_features(&self) -> u64 {
897             self.acked_features
898         }
899 
protocol_features(&self) -> VhostUserProtocolFeatures900         fn protocol_features(&self) -> VhostUserProtocolFeatures {
901             VhostUserProtocolFeatures::CONFIG
902         }
903 
ack_protocol_features(&mut self, features: u64) -> anyhow::Result<()>904         fn ack_protocol_features(&mut self, features: u64) -> anyhow::Result<()> {
905             let features = VhostUserProtocolFeatures::from_bits(features).ok_or(anyhow!(
906                 "invalid protocol features are given: 0x{:x}",
907                 features
908             ))?;
909             let supported = self.protocol_features();
910             self.acked_protocol_features = features & supported;
911             Ok(())
912         }
913 
acked_protocol_features(&self) -> u64914         fn acked_protocol_features(&self) -> u64 {
915             self.acked_protocol_features.bits()
916         }
917 
read_config(&self, offset: u64, dst: &mut [u8])918         fn read_config(&self, offset: u64, dst: &mut [u8]) {
919             dst.copy_from_slice(&FAKE_CONFIG_DATA.as_bytes()[offset as usize..]);
920         }
921 
reset(&mut self)922         fn reset(&mut self) {}
923 
start_queue( &mut self, _idx: usize, _queue: Queue, _mem: GuestMemory, _doorbell: Doorbell, _kick_evt: Event, ) -> anyhow::Result<()>924         fn start_queue(
925             &mut self,
926             _idx: usize,
927             _queue: Queue,
928             _mem: GuestMemory,
929             _doorbell: Doorbell,
930             _kick_evt: Event,
931         ) -> anyhow::Result<()> {
932             Ok(())
933         }
934 
stop_queue(&mut self, _idx: usize)935         fn stop_queue(&mut self, _idx: usize) {}
936     }
937 
938     #[cfg(unix)]
temp_dir() -> TempDir939     fn temp_dir() -> TempDir {
940         Builder::new().prefix("/tmp/vhost_test").tempdir().unwrap()
941     }
942 
943     #[cfg(unix)]
944     #[test]
test_vhost_user_activate()945     fn test_vhost_user_activate() {
946         use std::os::unix::net::UnixStream;
947 
948         use vmm_vhost::connection::socket::Listener as SocketListener;
949         use vmm_vhost::SlaveListener;
950 
951         const QUEUES_NUM: usize = 2;
952 
953         let dir = temp_dir();
954         let mut path = dir.path().to_owned();
955         path.push("sock");
956         let listener = SocketListener::new(&path, true).unwrap();
957 
958         let vmm_bar = Arc::new(Barrier::new(2));
959         let dev_bar = vmm_bar.clone();
960 
961         let (tx, rx) = channel();
962 
963         std::thread::spawn(move || {
964             // VMM side
965             rx.recv().unwrap(); // Ensure the device is ready.
966 
967             let allow_features = VhostUserVirtioFeatures::PROTOCOL_FEATURES.bits();
968             let init_features = VhostUserVirtioFeatures::PROTOCOL_FEATURES.bits();
969             let allow_protocol_features = VhostUserProtocolFeatures::CONFIG;
970             let connection = UnixStream::connect(&path).unwrap();
971             let mut vmm_handler = VhostUserHandler::new_from_connection(
972                 connection,
973                 QUEUES_NUM as u64,
974                 allow_features,
975                 init_features,
976                 allow_protocol_features,
977             )
978             .unwrap();
979 
980             println!("read_config");
981             let mut buf = vec![0; std::mem::size_of::<FakeConfig>()];
982             vmm_handler.read_config(0, &mut buf).unwrap();
983             // Check if the obtained config data is correct.
984             let config = FakeConfig::read_from(buf.as_bytes()).unwrap();
985             assert_eq!(config, FAKE_CONFIG_DATA);
986 
987             println!("set_mem_table");
988             let mem = GuestMemory::new(&[(GuestAddress(0x0), 0x10000)]).unwrap();
989             vmm_handler.set_mem_table(&mem).unwrap();
990 
991             for idx in 0..QUEUES_NUM {
992                 println!("activate_mem_table: queue_index={}", idx);
993                 let queue = Queue::new(0x10);
994                 let queue_evt = Event::new().unwrap();
995                 let irqfd = Event::new().unwrap();
996 
997                 vmm_handler
998                     .activate_vring(&mem, idx, &queue, &queue_evt, &irqfd)
999                     .unwrap();
1000             }
1001 
1002             // The VMM side is supposed to stop before the device side.
1003             drop(vmm_handler);
1004 
1005             vmm_bar.wait();
1006         });
1007 
1008         // Device side
1009         let handler = std::sync::Mutex::new(DeviceRequestHandler::new(
1010             Box::new(FakeBackend::new()),
1011             Box::new(VhostUserRegularOps),
1012         ));
1013         let mut listener = SlaveListener::<SocketListener, _>::new(listener, handler).unwrap();
1014 
1015         // Notify listener is ready.
1016         tx.send(()).unwrap();
1017 
1018         let mut listener = listener.accept().unwrap().unwrap();
1019 
1020         // VhostUserHandler::new()
1021         handle_request(&mut listener).expect("set_owner");
1022         handle_request(&mut listener).expect("get_features");
1023         handle_request(&mut listener).expect("set_features");
1024         handle_request(&mut listener).expect("get_protocol_features");
1025         handle_request(&mut listener).expect("set_protocol_features");
1026 
1027         // VhostUserHandler::read_config()
1028         handle_request(&mut listener).expect("get_config");
1029 
1030         // VhostUserHandler::set_mem_table()
1031         handle_request(&mut listener).expect("set_mem_table");
1032 
1033         for _ in 0..QUEUES_NUM {
1034             // VhostUserHandler::activate_vring()
1035             handle_request(&mut listener).expect("set_vring_num");
1036             handle_request(&mut listener).expect("set_vring_addr");
1037             handle_request(&mut listener).expect("set_vring_base");
1038             handle_request(&mut listener).expect("set_vring_call");
1039             handle_request(&mut listener).expect("set_vring_kick");
1040             handle_request(&mut listener).expect("set_vring_enable");
1041         }
1042 
1043         dev_bar.wait();
1044 
1045         match handle_request(&mut listener) {
1046             Err(VhostError::ClientExit) => (),
1047             r => panic!("Err(ClientExit) was expected but {:?}", r),
1048         }
1049     }
1050 
vmm_handler_send_requests(vmm_handler: &mut VhostUserHandler, queues_num: usize)1051     pub(super) fn vmm_handler_send_requests(vmm_handler: &mut VhostUserHandler, queues_num: usize) {
1052         println!("read_config");
1053         let mut buf = vec![0; std::mem::size_of::<FakeConfig>()];
1054         vmm_handler.read_config(0, &mut buf).unwrap();
1055         // Check if the obtained config data is correct.
1056         let config = FakeConfig::read_from(buf.as_bytes()).unwrap();
1057         assert_eq!(config, FAKE_CONFIG_DATA);
1058 
1059         println!("set_mem_table");
1060         let mem = GuestMemory::new(&[(GuestAddress(0x0), 0x10000)]).unwrap();
1061         vmm_handler.set_mem_table(&mem).unwrap();
1062 
1063         for idx in 0..queues_num {
1064             println!("activate_mem_table: queue_index={}", idx);
1065             let queue = Queue::new(0x10);
1066             let queue_evt = Event::new().unwrap();
1067             let irqfd = Event::new().unwrap();
1068 
1069             vmm_handler
1070                 .activate_vring(&mem, idx, &queue, &queue_evt, &irqfd)
1071                 .unwrap();
1072         }
1073     }
1074 
handle_request<S: VhostUserSlaveReqHandler, E: Endpoint<MasterReq>>( handler: &mut SlaveReqHandler<S, E>, ) -> Result<(), VhostError>1075     fn handle_request<S: VhostUserSlaveReqHandler, E: Endpoint<MasterReq>>(
1076         handler: &mut SlaveReqHandler<S, E>,
1077     ) -> Result<(), VhostError> {
1078         let (hdr, files) = handler.recv_header()?;
1079         handler.process_message(hdr, files)
1080     }
1081 
test_handle_requests<S: VhostUserSlaveReqHandler, E: Endpoint<MasterReq>>( req_handler: &mut SlaveReqHandler<S, E>, queues_num: usize, )1082     pub(super) fn test_handle_requests<S: VhostUserSlaveReqHandler, E: Endpoint<MasterReq>>(
1083         req_handler: &mut SlaveReqHandler<S, E>,
1084         queues_num: usize,
1085     ) {
1086         // VhostUserHandler::new()
1087         handle_request(req_handler).expect("set_owner");
1088         handle_request(req_handler).expect("get_features");
1089         handle_request(req_handler).expect("set_features");
1090         handle_request(req_handler).expect("get_protocol_features");
1091         handle_request(req_handler).expect("set_protocol_features");
1092 
1093         // VhostUserHandler::read_config()
1094         handle_request(req_handler).expect("get_config");
1095 
1096         // VhostUserHandler::set_mem_table()
1097         handle_request(req_handler).expect("set_mem_table");
1098 
1099         for _ in 0..queues_num {
1100             // VhostUserHandler::activate_vring()
1101             handle_request(req_handler).expect("set_vring_num");
1102             handle_request(req_handler).expect("set_vring_addr");
1103             handle_request(req_handler).expect("set_vring_base");
1104             handle_request(req_handler).expect("set_vring_call");
1105             handle_request(req_handler).expect("set_vring_kick");
1106             handle_request(req_handler).expect("set_vring_enable");
1107         }
1108     }
1109 }
1110