• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 // Copyright 2021 The Chromium OS Authors. All rights reserved.
2 // Use of this source code is governed by a BSD-style license that can be
3 // found in the LICENSE file.
4 
5 //! Library for implementing vhost-user device executables.
6 //!
7 //! This crate provides
8 //! * `VhostUserBackend` trait, which is a collection of methods to handle vhost-user requests, and
9 //! * `DeviceRequestHandler` struct, which makes a connection to a VMM and starts an event loop.
10 //!
11 //! They are expected to be used as follows:
12 //!
13 //! 1. Define a struct and implement `VhostUserBackend` for it.
14 //! 2. Create a `DeviceRequestHandler` with the backend struct.
15 //! 3. Drive the `DeviceRequestHandler::run` async fn with an executor.
16 //!
17 //! ```ignore
18 //! struct MyBackend {
19 //!   /* fields */
20 //! }
21 //!
22 //! impl VhostUserBackend for MyBackend {
23 //!   /* implement methods */
24 //! }
25 //!
26 //! fn main() -> Result<(), Box<dyn Error>> {
27 //!   let backend = MyBackend { /* initialize fields */ };
28 //!   let handler = DeviceRequestHandler::new(backend);
//!   let socket = std::path::Path::new("/path/to/socket");
30 //!   let ex = cros_async::Executor::new()?;
31 //!
32 //!   if let Err(e) = ex.run_until(handler.run(socket, &ex)) {
33 //!     eprintln!("error happened: {}", e);
34 //!   }
35 //!   Ok(())
36 //! }
37 //! ```
38 //!
39 // Implementation note:
40 // This code lets us take advantage of the vmm_vhost low level implementation of the vhost user
41 // protocol. DeviceRequestHandler implements the VhostUserSlaveReqHandlerMut trait from vmm_vhost,
42 // and includes some common code for setting up guest memory and managing partially configured
43 // vrings. DeviceRequestHandler::run watches the vhost-user socket and then calls handle_request()
44 // when it becomes readable. handle_request() reads and parses the message and then calls one of the
45 // VhostUserSlaveReqHandlerMut trait methods. These dispatch back to the supplied VhostUserBackend
46 // implementation (this is what our devices implement).
47 
48 use base::AsRawDescriptor;
49 use std::convert::{From, TryFrom};
50 use std::fs::File;
51 use std::num::Wrapping;
52 use std::os::unix::io::AsRawFd;
53 use std::os::unix::net::UnixListener;
54 use std::path::Path;
55 use std::sync::Arc;
56 
57 use anyhow::{anyhow, bail, Context, Result};
58 use base::{
59     clear_fd_flags, error, info, Event, FromRawDescriptor, IntoRawDescriptor, SafeDescriptor,
60     SharedMemory, SharedMemoryUnix, UnlinkUnixListener,
61 };
62 use cros_async::{AsyncWrapper, Executor};
63 use sync::Mutex;
64 use vm_memory::{GuestAddress, GuestMemory, MemoryRegion};
65 use vmm_vhost::{
66     connection::vfio::{Endpoint as VfioEndpoint, Listener as VfioListener},
67     message::{
68         VhostUserConfigFlags, VhostUserInflight, VhostUserMemoryRegion, VhostUserProtocolFeatures,
69         VhostUserSingleMemoryRegion, VhostUserVirtioFeatures, VhostUserVringAddrFlags,
70         VhostUserVringState,
71     },
72     Protocol, SlaveListener, SlaveReqHandler,
73 };
74 
75 use vmm_vhost::{Error as VhostError, Result as VhostResult, VhostUserSlaveReqHandlerMut};
76 
77 use crate::vfio::{VfioDevice, VfioRegionAddr};
78 use crate::virtio::vhost::user::device::vvu::{
79     device::VvuDevice,
80     doorbell::DoorbellRegion,
81     pci::{VvuPciCaps, VvuPciDevice},
82 };
83 use crate::virtio::{Queue, SignalableInterrupt};
84 
85 /// An event to deliver an interrupt to the guest.
86 ///
87 /// Unlike `devices::Interrupt`, this doesn't support interrupt status and signal resampling.
88 // TODO(b/187487351): To avoid sending unnecessary events, we might want to support interrupt
89 // status. For this purpose, we need a mechanism to share interrupt status between the vmm and the
90 // device process.
91 pub struct CallEvent(Event);
92 
93 impl SignalableInterrupt for CallEvent {
signal(&self, _vector: u16, _interrupt_status_mask: u32)94     fn signal(&self, _vector: u16, _interrupt_status_mask: u32) {
95         self.0.write(1).unwrap();
96     }
97 
signal_config_changed(&self)98     fn signal_config_changed(&self) {} // TODO(dgreid)
99 
get_resample_evt(&self) -> Option<&Event>100     fn get_resample_evt(&self) -> Option<&Event> {
101         None
102     }
103 
do_interrupt_resample(&self)104     fn do_interrupt_resample(&self) {}
105 }
106 
107 impl From<File> for CallEvent {
from(file: File) -> Self108     fn from(file: File) -> Self {
109         // Safe because we own the file.
110         CallEvent(unsafe { Event::from_raw_descriptor(file.into_raw_descriptor()) })
111     }
112 }
113 
/// One contiguous memory region, described both in the VMM's virtual address space and in
/// guest-physical address space.
///
/// Messages from the VMM refer to memory by VMM virtual addresses; these entries are used to
/// translate such addresses to guest offsets.
#[derive(Default)]
pub struct MappingInfo {
    /// Start of the region in the VMM's virtual address space.
    pub vmm_addr: u64,
    /// Start of the region in guest-physical address space.
    pub guest_phys: u64,
    /// Length of the region in bytes.
    pub size: u64,
}
122 
vmm_va_to_gpa(maps: &[MappingInfo], vmm_va: u64) -> VhostResult<GuestAddress>123 pub fn vmm_va_to_gpa(maps: &[MappingInfo], vmm_va: u64) -> VhostResult<GuestAddress> {
124     for map in maps {
125         if vmm_va >= map.vmm_addr && vmm_va < map.vmm_addr + map.size {
126             return Ok(GuestAddress(vmm_va - map.vmm_addr + map.guest_phys));
127         }
128     }
129     Err(VhostError::InvalidMessage)
130 }
131 
create_guest_memory( contexts: &[VhostUserMemoryRegion], files: Vec<File>, ) -> VhostResult<(GuestMemory, Vec<MappingInfo>)>132 pub fn create_guest_memory(
133     contexts: &[VhostUserMemoryRegion],
134     files: Vec<File>,
135 ) -> VhostResult<(GuestMemory, Vec<MappingInfo>)> {
136     let mut regions = Vec::with_capacity(files.len());
137     for (region, file) in contexts.iter().zip(files.into_iter()) {
138         let region = MemoryRegion::new_from_shm(
139             region.memory_size,
140             GuestAddress(region.guest_phys_addr),
141             region.mmap_offset,
142             Arc::new(SharedMemory::from_safe_descriptor(SafeDescriptor::from(file)).unwrap()),
143         )
144         .map_err(|e| {
145             error!("failed to create a memory region: {}", e);
146             VhostError::InvalidOperation
147         })?;
148         regions.push(region);
149     }
150     let guest_mem = GuestMemory::from_regions(regions).map_err(|e| {
151         error!("failed to create guest memory: {}", e);
152         VhostError::InvalidOperation
153     })?;
154 
155     let vmm_maps = contexts
156         .iter()
157         .map(|region| MappingInfo {
158             vmm_addr: region.user_addr,
159             guest_phys: region.guest_phys_addr,
160             size: region.memory_size,
161         })
162         .collect();
163     Ok((guest_mem, vmm_maps))
164 }
165 
create_vvu_guest_memory( vfio_dev: &VfioDevice, shared_mem_addr: &VfioRegionAddr, contexts: &[VhostUserMemoryRegion], ) -> VhostResult<(GuestMemory, Vec<MappingInfo>)>166 pub fn create_vvu_guest_memory(
167     vfio_dev: &VfioDevice,
168     shared_mem_addr: &VfioRegionAddr,
169     contexts: &[VhostUserMemoryRegion],
170 ) -> VhostResult<(GuestMemory, Vec<MappingInfo>)> {
171     let file_offset = vfio_dev.get_offset_for_addr(shared_mem_addr).map_err(|e| {
172         error!("failed to get underlying file: {}", e);
173         VhostError::InvalidOperation
174     })?;
175 
176     let mut vmm_maps = Vec::with_capacity(contexts.len());
177     let mut regions = Vec::with_capacity(contexts.len());
178     let page_size = base::pagesize() as u64;
179     for region in contexts {
180         let offset = file_offset + region.mmap_offset;
181         assert_eq!(offset % page_size, 0);
182 
183         vmm_maps.push(MappingInfo {
184             vmm_addr: region.user_addr as u64,
185             guest_phys: region.guest_phys_addr as u64,
186             size: region.memory_size,
187         });
188 
189         let cloned_file = vfio_dev.dev_file().try_clone().map_err(|e| {
190             error!("failed to clone vfio device file: {}", e);
191             VhostError::InvalidOperation
192         })?;
193         let region = MemoryRegion::new_from_file(
194             region.memory_size,
195             GuestAddress(region.guest_phys_addr),
196             file_offset + region.mmap_offset,
197             Arc::new(cloned_file),
198         )
199         .map_err(|e| {
200             error!("failed to create a memory region: {}", e);
201             VhostError::InvalidOperation
202         })?;
203         regions.push(region);
204     }
205 
206     let guest_mem = GuestMemory::from_regions(regions).map_err(|e| {
207         error!("failed to create guest memory: {}", e);
208         VhostError::InvalidOperation
209     })?;
210 
211     Ok((guest_mem, vmm_maps))
212 }
213 
214 /// Trait for vhost-user backend.
215 pub trait VhostUserBackend
216 where
217     Self: Sized,
218     Self::Error: std::fmt::Display,
219 {
220     const MAX_QUEUE_NUM: usize;
221     const MAX_VRING_LEN: u16;
222 
223     /// Error type specific to this backend.
224     type Error;
225 
226     /// The set of feature bits that this backend supports.
features(&self) -> u64227     fn features(&self) -> u64;
228 
229     /// Acknowledges that this set of features should be enabled.
ack_features(&mut self, value: u64) -> std::result::Result<(), Self::Error>230     fn ack_features(&mut self, value: u64) -> std::result::Result<(), Self::Error>;
231 
232     /// Returns the set of enabled features.
acked_features(&self) -> u64233     fn acked_features(&self) -> u64;
234 
235     /// The set of protocol feature bits that this backend supports.
protocol_features(&self) -> VhostUserProtocolFeatures236     fn protocol_features(&self) -> VhostUserProtocolFeatures;
237 
238     /// Acknowledges that this set of protocol features should be enabled.
ack_protocol_features(&mut self, _value: u64) -> std::result::Result<(), Self::Error>239     fn ack_protocol_features(&mut self, _value: u64) -> std::result::Result<(), Self::Error>;
240 
241     /// Returns the set of enabled protocol features.
acked_protocol_features(&self) -> u64242     fn acked_protocol_features(&self) -> u64;
243 
244     /// Reads this device configuration space at `offset`.
read_config(&self, offset: u64, dst: &mut [u8])245     fn read_config(&self, offset: u64, dst: &mut [u8]);
246 
247     /// writes `data` to this device's configuration space at `offset`.
write_config(&self, _offset: u64, _data: &[u8])248     fn write_config(&self, _offset: u64, _data: &[u8]) {}
249 
250     /// Sets the channel for device-specific communication.
set_device_request_channel( &mut self, _channel: File, ) -> std::result::Result<(), Self::Error>251     fn set_device_request_channel(
252         &mut self,
253         _channel: File,
254     ) -> std::result::Result<(), Self::Error> {
255         Ok(())
256     }
257 
258     /// Indicates that the backend should start processing requests for virtio queue number `idx`.
259     /// This method must not block the current thread so device backends should either spawn an
260     /// async task or another thread to handle messages from the Queue.
start_queue( &mut self, idx: usize, queue: Queue, mem: GuestMemory, doorbell: Arc<Mutex<Doorbell>>, kick_evt: Event, ) -> std::result::Result<(), Self::Error>261     fn start_queue(
262         &mut self,
263         idx: usize,
264         queue: Queue,
265         mem: GuestMemory,
266         doorbell: Arc<Mutex<Doorbell>>,
267         kick_evt: Event,
268     ) -> std::result::Result<(), Self::Error>;
269 
270     /// Indicates that the backend should stop processing requests for virtio queue number `idx`.
stop_queue(&mut self, idx: usize)271     fn stop_queue(&mut self, idx: usize);
272 
273     /// Resets the vhost-user backend.
reset(&mut self)274     fn reset(&mut self);
275 }
276 
277 pub enum Doorbell {
278     Call(CallEvent),
279     Vfio(DoorbellRegion),
280 }
281 
282 impl SignalableInterrupt for Doorbell {
signal(&self, vector: u16, interrupt_status_mask: u32)283     fn signal(&self, vector: u16, interrupt_status_mask: u32) {
284         match &self {
285             Self::Call(evt) => evt.signal(vector, interrupt_status_mask),
286             Self::Vfio(evt) => evt.signal(vector, interrupt_status_mask),
287         }
288     }
289 
signal_config_changed(&self)290     fn signal_config_changed(&self) {
291         match &self {
292             Self::Call(evt) => evt.signal_config_changed(),
293             Self::Vfio(evt) => evt.signal_config_changed(),
294         }
295     }
296 
get_resample_evt(&self) -> Option<&Event>297     fn get_resample_evt(&self) -> Option<&Event> {
298         match &self {
299             Self::Call(evt) => evt.get_resample_evt(),
300             Self::Vfio(evt) => evt.get_resample_evt(),
301         }
302     }
303 
do_interrupt_resample(&self)304     fn do_interrupt_resample(&self) {
305         match &self {
306             Self::Call(evt) => evt.do_interrupt_resample(),
307             Self::Vfio(evt) => evt.do_interrupt_resample(),
308         }
309     }
310 }
311 
312 /// A virtio ring entry.
313 struct Vring {
314     queue: Queue,
315     doorbell: Option<Arc<Mutex<Doorbell>>>,
316     enabled: bool,
317 }
318 
319 impl Vring {
new(max_size: u16) -> Self320     fn new(max_size: u16) -> Self {
321         Self {
322             queue: Queue::new(max_size),
323             doorbell: None,
324             enabled: false,
325         }
326     }
327 
reset(&mut self)328     fn reset(&mut self) {
329         self.queue.reset();
330         self.doorbell = None;
331         self.enabled = false;
332     }
333 }
334 
335 pub(super) enum HandlerType {
336     VhostUser,
337     Vvu {
338         vfio_dev: Arc<VfioDevice>,
339         caps: VvuPciCaps,
340         notification_evts: Vec<Event>,
341     },
342 }
343 
344 impl Default for HandlerType {
default() -> Self345     fn default() -> Self {
346         Self::VhostUser
347     }
348 }
349 
350 /// Structure to have an event loop for interaction between a VMM and `VhostUserBackend`.
351 pub struct DeviceRequestHandler<B>
352 where
353     B: 'static + VhostUserBackend,
354 {
355     vrings: Vec<Vring>,
356     owned: bool,
357     vmm_maps: Option<Vec<MappingInfo>>,
358     mem: Option<GuestMemory>,
359     backend: B,
360 
361     handler_type: HandlerType,
362 }
363 
364 impl<B> DeviceRequestHandler<B>
365 where
366     B: 'static + VhostUserBackend,
367 {
368     /// Creates the vhost-user handler instance for `backend`.
new(backend: B) -> Self369     pub fn new(backend: B) -> Self {
370         let mut vrings = Vec::with_capacity(B::MAX_QUEUE_NUM);
371         for _ in 0..B::MAX_QUEUE_NUM {
372             vrings.push(Vring::new(B::MAX_VRING_LEN as u16));
373         }
374 
375         DeviceRequestHandler {
376             vrings,
377             owned: false,
378             vmm_maps: None,
379             mem: None,
380             backend,
381             handler_type: Default::default(), // For vvu, this field will be overwritten later.
382         }
383     }
384 
385     /// Creates a listening socket at `socket` and handles incoming messages from the VMM, which are
386     /// dispatched to the device backend via the `VhostUserBackend` trait methods.
run<P: AsRef<Path>>(self, socket: P, ex: &Executor) -> Result<()>387     pub async fn run<P: AsRef<Path>>(self, socket: P, ex: &Executor) -> Result<()> {
388         let listener = UnixListener::bind(socket)
389             .map(UnlinkUnixListener)
390             .context("failed to create a UNIX domain socket listener")?;
391         return self.run_with_listener(listener, ex).await;
392     }
393 
394     /// Attaches to an already bound socket via `listener` and handles incoming messages from the
395     /// VMM, which are dispatched to the device backend via the `VhostUserBackend` trait methods.
run_with_listener( self, listener: UnlinkUnixListener, ex: &Executor, ) -> Result<()>396     pub async fn run_with_listener(
397         self,
398         listener: UnlinkUnixListener,
399         ex: &Executor,
400     ) -> Result<()> {
401         let (socket, _) = ex
402             .spawn_blocking(move || {
403                 listener
404                     .accept()
405                     .context("failed to accept an incoming connection")
406             })
407             .await?;
408         let mut req_handler =
409             SlaveReqHandler::from_stream(socket, Arc::new(std::sync::Mutex::new(self)));
410         let h = SafeDescriptor::try_from(&req_handler as &dyn AsRawDescriptor)
411             .map(AsyncWrapper::new)
412             .expect("failed to get safe descriptor for handler");
413         let handler_source = ex
414             .async_from(h)
415             .context("failed to create an async source")?;
416 
417         loop {
418             handler_source
419                 .wait_readable()
420                 .await
421                 .context("failed to wait for the handler socket to become readable")?;
422             match req_handler.handle_request() {
423                 Ok(()) => (),
424                 Err(VhostError::ClientExit) => {
425                     info!("vhost-user connection closed");
426                     // Exit as the client closed the connection.
427                     return Ok(());
428                 }
429                 Err(e) => {
430                     bail!("failed to handle a vhost-user request: {}", e);
431                 }
432             };
433         }
434     }
435 
436     /// Starts listening virtio-vhost-user device with VFIO to handle incoming vhost-user messages
437     /// forwarded by it.
run_vvu(mut self, mut device: VvuPciDevice, ex: &Executor) -> Result<()>438     pub async fn run_vvu(mut self, mut device: VvuPciDevice, ex: &Executor) -> Result<()> {
439         self.handler_type = HandlerType::Vvu {
440             vfio_dev: Arc::clone(&device.vfio_dev),
441             caps: device.caps.clone(),
442             notification_evts: std::mem::take(&mut device.notification_evts),
443         };
444         let driver = VvuDevice::new(device);
445 
446         let mut listener = VfioListener::new(driver)
447             .map_err(|e| anyhow!("failed to create a VFIO listener: {}", e))
448             .and_then(|l| {
449                 SlaveListener::<VfioEndpoint<_, _>, _>::new(
450                     l,
451                     Arc::new(std::sync::Mutex::new(self)),
452                 )
453                 .map_err(|e| anyhow!("failed to create SlaveListener: {}", e))
454             })?;
455 
456         let mut req_handler = listener
457             .accept()
458             .map_err(|e| anyhow!("failed to accept VFIO connection: {}", e))?
459             .expect("vvu proxy is unavailable via VFIO");
460 
461         let h = SafeDescriptor::try_from(&req_handler as &dyn AsRawDescriptor)
462             .map(AsyncWrapper::new)
463             .expect("failed to get safe descriptor for handler");
464         let handler_source = ex
465             .async_from(h)
466             .context("failed to create asyn handler source")?;
467 
468         let done = async move {
469             loop {
470                 // Wait for requests from the sibling.
471                 // `read_u64()` returns the number of requests arrived.
472                 let count = handler_source
473                     .read_u64()
474                     .await
475                     .context("failed to wait for handler source")?;
476                 for _ in 0..count {
477                     req_handler
478                         .handle_request()
479                         .context("failed to handle request")?;
480                 }
481             }
482         };
483         match ex.run_until(done) {
484             Ok(Ok(())) => Ok(()),
485             Ok(Err(e)) => Err(e),
486             Err(e) => Err(e).context("executor error"),
487         }
488     }
489 }
490 
491 impl<B: VhostUserBackend> VhostUserSlaveReqHandlerMut for DeviceRequestHandler<B> {
protocol(&self) -> Protocol492     fn protocol(&self) -> Protocol {
493         match self.handler_type {
494             HandlerType::VhostUser => Protocol::Regular,
495             HandlerType::Vvu { .. } => Protocol::Virtio,
496         }
497     }
498 
set_owner(&mut self) -> VhostResult<()>499     fn set_owner(&mut self) -> VhostResult<()> {
500         if self.owned {
501             return Err(VhostError::InvalidOperation);
502         }
503         self.owned = true;
504         Ok(())
505     }
506 
reset_owner(&mut self) -> VhostResult<()>507     fn reset_owner(&mut self) -> VhostResult<()> {
508         self.owned = false;
509         self.backend.reset();
510         Ok(())
511     }
512 
get_features(&mut self) -> VhostResult<u64>513     fn get_features(&mut self) -> VhostResult<u64> {
514         let features = self.backend.features();
515         Ok(features)
516     }
517 
set_features(&mut self, features: u64) -> VhostResult<()>518     fn set_features(&mut self, features: u64) -> VhostResult<()> {
519         if !self.owned {
520             return Err(VhostError::InvalidOperation);
521         }
522 
523         if (features & !(self.backend.features())) != 0 {
524             return Err(VhostError::InvalidParam);
525         }
526 
527         if let Err(e) = self.backend.ack_features(features) {
528             error!("failed to acknowledge features 0x{:x}: {}", features, e);
529             return Err(VhostError::InvalidOperation);
530         }
531 
532         // If VHOST_USER_F_PROTOCOL_FEATURES has not been negotiated, the ring is initialized in an
533         // enabled state.
534         // If VHOST_USER_F_PROTOCOL_FEATURES has been negotiated, the ring is initialized in a
535         // disabled state.
536         // Client must not pass data to/from the backend until ring is enabled by
537         // VHOST_USER_SET_VRING_ENABLE with parameter 1, or after it has been disabled by
538         // VHOST_USER_SET_VRING_ENABLE with parameter 0.
539         let acked_features = self.backend.acked_features();
540         let vring_enabled = VhostUserVirtioFeatures::PROTOCOL_FEATURES.bits() & acked_features != 0;
541         for v in &mut self.vrings {
542             v.enabled = vring_enabled;
543         }
544 
545         Ok(())
546     }
547 
get_protocol_features(&mut self) -> VhostResult<VhostUserProtocolFeatures>548     fn get_protocol_features(&mut self) -> VhostResult<VhostUserProtocolFeatures> {
549         Ok(self.backend.protocol_features())
550     }
551 
set_protocol_features(&mut self, features: u64) -> VhostResult<()>552     fn set_protocol_features(&mut self, features: u64) -> VhostResult<()> {
553         if let Err(e) = self.backend.ack_protocol_features(features) {
554             error!("failed to set protocol features 0x{:x}: {}", features, e);
555             return Err(VhostError::InvalidOperation);
556         }
557         Ok(())
558     }
559 
set_mem_table( &mut self, contexts: &[VhostUserMemoryRegion], files: Vec<File>, ) -> VhostResult<()>560     fn set_mem_table(
561         &mut self,
562         contexts: &[VhostUserMemoryRegion],
563         files: Vec<File>,
564     ) -> VhostResult<()> {
565         let (guest_mem, vmm_maps) = match &self.handler_type {
566             HandlerType::VhostUser => {
567                 if files.len() != contexts.len() {
568                     return Err(VhostError::InvalidParam);
569                 }
570                 create_guest_memory(contexts, files)?
571             }
572             HandlerType::Vvu {
573                 vfio_dev: device,
574                 caps,
575                 ..
576             } => {
577                 // virtio-vhost-user doesn't pass FDs.
578                 if !files.is_empty() {
579                     return Err(VhostError::InvalidParam);
580                 }
581                 create_vvu_guest_memory(device.as_ref(), caps.shared_mem_cfg_addr(), contexts)?
582             }
583         };
584 
585         self.mem = Some(guest_mem);
586         self.vmm_maps = Some(vmm_maps);
587         Ok(())
588     }
589 
get_queue_num(&mut self) -> VhostResult<u64>590     fn get_queue_num(&mut self) -> VhostResult<u64> {
591         Ok(self.vrings.len() as u64)
592     }
593 
set_vring_num(&mut self, index: u32, num: u32) -> VhostResult<()>594     fn set_vring_num(&mut self, index: u32, num: u32) -> VhostResult<()> {
595         if index as usize >= self.vrings.len() || num == 0 || num > B::MAX_VRING_LEN.into() {
596             return Err(VhostError::InvalidParam);
597         }
598         self.vrings[index as usize].queue.size = num as u16;
599 
600         Ok(())
601     }
602 
set_vring_addr( &mut self, index: u32, _flags: VhostUserVringAddrFlags, descriptor: u64, used: u64, available: u64, _log: u64, ) -> VhostResult<()>603     fn set_vring_addr(
604         &mut self,
605         index: u32,
606         _flags: VhostUserVringAddrFlags,
607         descriptor: u64,
608         used: u64,
609         available: u64,
610         _log: u64,
611     ) -> VhostResult<()> {
612         if index as usize >= self.vrings.len() {
613             return Err(VhostError::InvalidParam);
614         }
615 
616         let vmm_maps = self.vmm_maps.as_ref().ok_or(VhostError::InvalidParam)?;
617         let vring = &mut self.vrings[index as usize];
618         vring.queue.desc_table = vmm_va_to_gpa(vmm_maps, descriptor)?;
619         vring.queue.avail_ring = vmm_va_to_gpa(vmm_maps, available)?;
620         vring.queue.used_ring = vmm_va_to_gpa(vmm_maps, used)?;
621 
622         Ok(())
623     }
624 
set_vring_base(&mut self, index: u32, base: u32) -> VhostResult<()>625     fn set_vring_base(&mut self, index: u32, base: u32) -> VhostResult<()> {
626         if index as usize >= self.vrings.len() || base >= B::MAX_VRING_LEN.into() {
627             return Err(VhostError::InvalidParam);
628         }
629 
630         let vring = &mut self.vrings[index as usize];
631         vring.queue.next_avail = Wrapping(base as u16);
632         vring.queue.next_used = Wrapping(base as u16);
633 
634         Ok(())
635     }
636 
get_vring_base(&mut self, index: u32) -> VhostResult<VhostUserVringState>637     fn get_vring_base(&mut self, index: u32) -> VhostResult<VhostUserVringState> {
638         if index as usize >= self.vrings.len() {
639             return Err(VhostError::InvalidParam);
640         }
641 
642         // Quotation from vhost-user spec:
643         // Client must start ring upon receiving a kick (that is, detecting
644         // that file descriptor is readable) on the descriptor specified by
645         // VHOST_USER_SET_VRING_KICK, and stop ring upon receiving
646         // VHOST_USER_GET_VRING_BASE.
647         self.backend.stop_queue(index as usize);
648 
649         let vring = &mut self.vrings[index as usize];
650         vring.reset();
651 
652         Ok(VhostUserVringState::new(
653             index,
654             vring.queue.next_avail.0 as u32,
655         ))
656     }
657 
set_vring_kick(&mut self, index: u8, file: Option<File>) -> VhostResult<()>658     fn set_vring_kick(&mut self, index: u8, file: Option<File>) -> VhostResult<()> {
659         if index as usize >= self.vrings.len() {
660             return Err(VhostError::InvalidParam);
661         }
662 
663         let vring = &mut self.vrings[index as usize];
664         if vring.queue.ready {
665             error!("kick fd cannot replaced after queue is started");
666             return Err(VhostError::InvalidOperation);
667         }
668 
669         let kick_evt = match &self.handler_type {
670             HandlerType::VhostUser => {
671                 let file = file.ok_or(VhostError::InvalidParam)?;
672                 // Remove O_NONBLOCK from kick_fd. Otherwise, uring_executor will fails when we read
673                 // values via `next_val()` later.
674                 if let Err(e) = clear_fd_flags(file.as_raw_fd(), libc::O_NONBLOCK) {
675                     error!("failed to remove O_NONBLOCK for kick fd: {}", e);
676                     return Err(VhostError::InvalidParam);
677                 }
678 
679                 // Safe because we own the file.
680                 unsafe { Event::from_raw_descriptor(file.into_raw_descriptor()) }
681             }
682             HandlerType::Vvu {
683                 notification_evts, ..
684             } => {
685                 if file.is_some() {
686                     return Err(VhostError::InvalidParam);
687                 }
688                 notification_evts[index as usize].try_clone().map_err(|e| {
689                     error!("failed to clone notification_evts[{}]: {}", index, e);
690                     VhostError::InvalidOperation
691                 })?
692             }
693         };
694 
695         let vring = &mut self.vrings[index as usize];
696         vring.queue.ready = true;
697 
698         let queue = vring.queue.clone();
699         let doorbell = vring
700             .doorbell
701             .as_ref()
702             .ok_or(VhostError::InvalidOperation)?;
703         let mem = self
704             .mem
705             .as_ref()
706             .cloned()
707             .ok_or(VhostError::InvalidOperation)?;
708 
709         if let Err(e) =
710             self.backend
711                 .start_queue(index as usize, queue, mem, Arc::clone(doorbell), kick_evt)
712         {
713             error!("Failed to start queue {}: {}", index, e);
714             return Err(VhostError::SlaveInternalError);
715         }
716 
717         Ok(())
718     }
719 
set_vring_call(&mut self, index: u8, file: Option<File>) -> VhostResult<()>720     fn set_vring_call(&mut self, index: u8, file: Option<File>) -> VhostResult<()> {
721         if index as usize >= self.vrings.len() {
722             return Err(VhostError::InvalidParam);
723         }
724 
725         let doorbell = match &self.handler_type {
726             HandlerType::VhostUser => {
727                 let file = file.ok_or(VhostError::InvalidParam)?;
728                 Doorbell::Call(CallEvent::try_from(file).map_err(|_| {
729                     error!("failed to convert callfd to CallSignal");
730                     VhostError::InvalidParam
731                 })?)
732             }
733             HandlerType::Vvu {
734                 vfio_dev: device,
735                 caps,
736                 ..
737             } => {
738                 let base = caps.doorbell_base_addr();
739                 let addr = VfioRegionAddr {
740                     index: base.index,
741                     addr: base.addr + (index as u64 * caps.doorbell_off_multiplier() as u64),
742                 };
743                 Doorbell::Vfio(DoorbellRegion {
744                     vfio: Arc::clone(device),
745                     index,
746                     addr,
747                 })
748             }
749         };
750 
751         match &self.vrings[index as usize].doorbell {
752             None => {
753                 self.vrings[index as usize].doorbell = Some(Arc::new(Mutex::new(doorbell)));
754             }
755             Some(cell) => {
756                 let mut evt = cell.lock();
757                 *evt = doorbell;
758             }
759         }
760 
761         Ok(())
762     }
763 
set_vring_err(&mut self, _index: u8, _fd: Option<File>) -> VhostResult<()>764     fn set_vring_err(&mut self, _index: u8, _fd: Option<File>) -> VhostResult<()> {
765         // TODO
766         Ok(())
767     }
768 
set_vring_enable(&mut self, index: u32, enable: bool) -> VhostResult<()>769     fn set_vring_enable(&mut self, index: u32, enable: bool) -> VhostResult<()> {
770         if index as usize >= self.vrings.len() {
771             return Err(VhostError::InvalidParam);
772         }
773 
774         // This request should be handled only when VHOST_USER_F_PROTOCOL_FEATURES
775         // has been negotiated.
776         if self.backend.acked_features() & VhostUserVirtioFeatures::PROTOCOL_FEATURES.bits() == 0 {
777             return Err(VhostError::InvalidOperation);
778         }
779 
780         // Slave must not pass data to/from the backend until ring is
781         // enabled by VHOST_USER_SET_VRING_ENABLE with parameter 1,
782         // or after it has been disabled by VHOST_USER_SET_VRING_ENABLE
783         // with parameter 0.
784         self.vrings[index as usize].enabled = enable;
785 
786         Ok(())
787     }
788 
get_config( &mut self, offset: u32, size: u32, _flags: VhostUserConfigFlags, ) -> VhostResult<Vec<u8>>789     fn get_config(
790         &mut self,
791         offset: u32,
792         size: u32,
793         _flags: VhostUserConfigFlags,
794     ) -> VhostResult<Vec<u8>> {
795         if offset >= size {
796             return Err(VhostError::InvalidParam);
797         }
798 
799         let mut data = vec![0; size as usize];
800         self.backend.read_config(u64::from(offset), &mut data);
801         Ok(data)
802     }
803 
set_config( &mut self, offset: u32, buf: &[u8], _flags: VhostUserConfigFlags, ) -> VhostResult<()>804     fn set_config(
805         &mut self,
806         offset: u32,
807         buf: &[u8],
808         _flags: VhostUserConfigFlags,
809     ) -> VhostResult<()> {
810         self.backend.write_config(u64::from(offset), buf);
811         Ok(())
812     }
813 
set_slave_req_fd(&mut self, fd: File)814     fn set_slave_req_fd(&mut self, fd: File) {
815         if let Err(e) = self.backend.set_device_request_channel(fd) {
816             error!("failed to set device request channel: {}", e);
817         }
818     }
819 
get_inflight_fd( &mut self, _inflight: &VhostUserInflight, ) -> VhostResult<(VhostUserInflight, File)>820     fn get_inflight_fd(
821         &mut self,
822         _inflight: &VhostUserInflight,
823     ) -> VhostResult<(VhostUserInflight, File)> {
824         unimplemented!("get_inflight_fd");
825     }
826 
set_inflight_fd(&mut self, _inflight: &VhostUserInflight, _file: File) -> VhostResult<()>827     fn set_inflight_fd(&mut self, _inflight: &VhostUserInflight, _file: File) -> VhostResult<()> {
828         unimplemented!("set_inflight_fd");
829     }
830 
get_max_mem_slots(&mut self) -> VhostResult<u64>831     fn get_max_mem_slots(&mut self) -> VhostResult<u64> {
832         //TODO
833         Ok(0)
834     }
835 
add_mem_region( &mut self, _region: &VhostUserSingleMemoryRegion, _fd: File, ) -> VhostResult<()>836     fn add_mem_region(
837         &mut self,
838         _region: &VhostUserSingleMemoryRegion,
839         _fd: File,
840     ) -> VhostResult<()> {
841         //TODO
842         Ok(())
843     }
844 
remove_mem_region(&mut self, _region: &VhostUserSingleMemoryRegion) -> VhostResult<()>845     fn remove_mem_region(&mut self, _region: &VhostUserSingleMemoryRegion) -> VhostResult<()> {
846         //TODO
847         Ok(())
848     }
849 }
850 
#[cfg(test)]
mod tests {
    use super::*;

    use std::sync::mpsc::channel;
    use std::sync::Barrier;

    use anyhow::{anyhow, bail};
    use data_model::DataInit;
    use tempfile::{Builder, TempDir};

    use crate::virtio::vhost::user::vmm::VhostUserHandler;

    /// Device configuration space served by `FakeBackend::read_config`.
    #[derive(Clone, Copy, Debug, PartialEq, Eq)]
    #[repr(C)]
    struct FakeConfig {
        x: u32,
        // Explicit padding. Without it, on targets where `u64` is 8-byte aligned,
        // `repr(C)` inserts 4 uninitialized padding bytes here, and
        // `DataInit::as_slice` would expose them.
        _pad: u32,
        y: u64,
    }

    // SAFETY: `FakeConfig` is `repr(C)`, consists only of integers, and has no implicit
    // padding, so it can be viewed as, and built from, plain bytes.
    unsafe impl DataInit for FakeConfig {}

    const FAKE_CONFIG_DATA: FakeConfig = FakeConfig { x: 1, _pad: 0, y: 2 };

    /// Minimal backend: negotiates features and serves `FAKE_CONFIG_DATA` as its config.
    struct FakeBackend {
        avail_features: u64,
        acked_features: u64,
        acked_protocol_features: VhostUserProtocolFeatures,
    }

    impl FakeBackend {
        fn new() -> Self {
            Self {
                avail_features: VhostUserVirtioFeatures::PROTOCOL_FEATURES.bits(),
                acked_features: 0,
                acked_protocol_features: VhostUserProtocolFeatures::empty(),
            }
        }
    }

    impl VhostUserBackend for FakeBackend {
        const MAX_QUEUE_NUM: usize = 16;
        const MAX_VRING_LEN: u16 = 256;

        type Error = anyhow::Error;

        fn features(&self) -> u64 {
            self.avail_features
        }

        fn ack_features(&mut self, value: u64) -> std::result::Result<(), Self::Error> {
            // These are virtio feature bits, not protocol feature bits.
            let unrequested_features = value & !self.avail_features;
            if unrequested_features != 0 {
                bail!("invalid features are given: 0x{:x}", unrequested_features);
            }
            self.acked_features |= value;
            Ok(())
        }

        fn acked_features(&self) -> u64 {
            self.acked_features
        }

        fn protocol_features(&self) -> VhostUserProtocolFeatures {
            VhostUserProtocolFeatures::CONFIG
        }

        fn ack_protocol_features(&mut self, features: u64) -> std::result::Result<(), Self::Error> {
            // Build the error lazily; it is only needed when unknown bits are present.
            let features = VhostUserProtocolFeatures::from_bits(features).ok_or_else(|| {
                anyhow!("invalid protocol features are given: 0x{:x}", features)
            })?;
            let supported = self.protocol_features();
            self.acked_protocol_features = features & supported;
            Ok(())
        }

        fn acked_protocol_features(&self) -> u64 {
            self.acked_protocol_features.bits()
        }

        fn read_config(&self, offset: u64, dst: &mut [u8]) {
            // Copy exactly `dst.len()` bytes so that reads of a sub-range (e.g. a single
            // field) work. An out-of-range request panics, which fails the test.
            let start = offset as usize;
            dst.copy_from_slice(&FAKE_CONFIG_DATA.as_slice()[start..start + dst.len()]);
        }

        fn reset(&mut self) {}

        fn start_queue(
            &mut self,
            _idx: usize,
            _queue: Queue,
            _mem: GuestMemory,
            _doorbell: Arc<Mutex<Doorbell>>,
            _kick_evt: Event,
        ) -> std::result::Result<(), Self::Error> {
            Ok(())
        }

        fn stop_queue(&mut self, _idx: usize) {}
    }

    fn temp_dir() -> TempDir {
        Builder::new().prefix("/tmp/vhost_test").tempdir().unwrap()
    }

    #[test]
    fn test_vhost_user_activate() {
        use vmm_vhost::{
            connection::socket::{Endpoint as SocketEndpoint, Listener as SocketListener},
            SlaveListener,
        };

        const QUEUES_NUM: usize = 2;

        let dir = temp_dir();
        let mut path = dir.path().to_owned();
        path.push("sock");
        let listener = SocketListener::new(&path, true).unwrap();

        let vmm_bar = Arc::new(Barrier::new(2));
        let dev_bar = Arc::clone(&vmm_bar);

        let (tx, rx) = channel();

        // Keep the handle so that a panic on the VMM side is propagated by `join` below.
        let vmm_thread = std::thread::spawn(move || {
            // VMM side
            rx.recv().unwrap(); // Ensure the device is ready.

            let allow_features = VhostUserVirtioFeatures::PROTOCOL_FEATURES.bits();
            let init_features = VhostUserVirtioFeatures::PROTOCOL_FEATURES.bits();
            let allow_protocol_features = VhostUserProtocolFeatures::CONFIG;
            let mut vmm_handler = VhostUserHandler::new_from_path(
                &path,
                QUEUES_NUM as u64,
                allow_features,
                init_features,
                allow_protocol_features,
            )
            .unwrap();

            println!("read_config");
            let mut buf = vec![0; std::mem::size_of::<FakeConfig>()];
            vmm_handler.read_config::<FakeConfig>(0, &mut buf).unwrap();
            // Check if the obtained config data is correct.
            let config = FakeConfig::from_slice(&buf).unwrap();
            assert_eq!(*config, FAKE_CONFIG_DATA);

            println!("set_mem_table");
            let mem = GuestMemory::new(&[(GuestAddress(0x0), 0x10000)]).unwrap();
            vmm_handler.set_mem_table(&mem).unwrap();

            for idx in 0..QUEUES_NUM {
                println!("activate_vring: queue_index={}", idx);
                let queue = Queue::new(0x10);
                let queue_evt = Event::new().unwrap();
                let irqfd = Event::new().unwrap();

                vmm_handler
                    .activate_vring(&mem, idx, &queue, &queue_evt, &irqfd)
                    .unwrap();
            }

            // The VMM side is supposed to stop before the device side.
            drop(vmm_handler);

            vmm_bar.wait();
        });

        // Device side
        let handler = Arc::new(std::sync::Mutex::new(DeviceRequestHandler::new(
            FakeBackend::new(),
        )));
        let mut listener = SlaveListener::<SocketEndpoint<_>, _>::new(listener, handler).unwrap();

        // Notify listener is ready.
        tx.send(()).unwrap();

        let mut listener = listener.accept().unwrap().unwrap();

        // VhostUserHandler::new()
        listener.handle_request().expect("set_owner");
        listener.handle_request().expect("get_features");
        listener.handle_request().expect("set_features");
        listener.handle_request().expect("get_protocol_features");
        listener.handle_request().expect("set_protocol_features");

        // VhostUserHandler::read_config()
        listener.handle_request().expect("get_config");

        // VhostUserHandler::set_mem_table()
        listener.handle_request().expect("set_mem_table");

        for _ in 0..QUEUES_NUM {
            // VhostUserHandler::activate_vring()
            listener.handle_request().expect("set_vring_num");
            listener.handle_request().expect("set_vring_addr");
            listener.handle_request().expect("set_vring_base");
            listener.handle_request().expect("set_vring_call");
            listener.handle_request().expect("set_vring_kick");
            listener.handle_request().expect("set_vring_enable");
        }

        dev_bar.wait();

        match listener.handle_request() {
            Err(VhostError::ClientExit) => (),
            r => panic!("Err(ClientExit) was expected but {:?}", r),
        }

        vmm_thread.join().expect("VMM thread panicked");
    }
}
1064