• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 // Copyright 2021 The ChromiumOS Authors
2 // Use of this source code is governed by a BSD-style license that can be
3 // found in the LICENSE file.
4 
5 //! Library for implementing vhost-user device executables.
6 //!
7 //! This crate provides
8 //! * `VhostUserDevice` trait, which is a collection of methods to handle vhost-user requests, and
9 //! * `DeviceRequestHandler` struct, which makes a connection to a VMM and starts an event loop.
10 //!
11 //! They are expected to be used as follows:
12 //!
13 //! 1. Define a struct and implement `VhostUserDevice` for it.
14 //! 2. Create a `DeviceRequestHandler` with the backend struct.
15 //! 3. Drive the `DeviceRequestHandler::run` async fn with an executor.
16 //!
17 //! ```ignore
18 //! struct MyBackend {
19 //!   /* fields */
20 //! }
21 //!
22 //! impl VhostUserDevice for MyBackend {
23 //!   /* implement methods */
24 //! }
25 //!
26 //! fn main() -> Result<(), Box<dyn Error>> {
27 //!   let backend = MyBackend { /* initialize fields */ };
28 //!   let handler = DeviceRequestHandler::new(backend);
29 //!   let socket = std::path::Path::new("/path/to/socket");
30 //!   let ex = cros_async::Executor::new()?;
31 //!
32 //!   if let Err(e) = ex.run_until(handler.run(socket, &ex)) {
33 //!     eprintln!("error happened: {}", e);
34 //!   }
35 //!   Ok(())
36 //! }
37 //! ```
38 // Implementation note:
39 // This code lets us take advantage of the vmm_vhost low level implementation of the vhost user
40 // protocol. DeviceRequestHandler implements the Backend trait from vmm_vhost, and includes some
41 // common code for setting up guest memory and managing partially configured vrings.
42 // DeviceRequestHandler::run watches the vhost-user socket and then calls handle_request() when it
43 // becomes readable. handle_request() reads and parses the message and then calls one of the
44 // Backend trait methods. These dispatch back to the supplied VhostUserDevice implementation (this
45 // is what our devices implement).
46 
47 pub(super) mod sys;
48 
49 use std::collections::BTreeMap;
50 use std::convert::From;
51 use std::fs::File;
52 use std::num::Wrapping;
53 #[cfg(any(target_os = "android", target_os = "linux"))]
54 use std::os::unix::io::AsRawFd;
55 use std::sync::Arc;
56 
57 use anyhow::bail;
58 use anyhow::Context;
59 #[cfg(any(target_os = "android", target_os = "linux"))]
60 use base::clear_fd_flags;
61 use base::error;
62 use base::warn;
63 use base::Event;
64 use base::FromRawDescriptor;
65 use base::IntoRawDescriptor;
66 use base::Protection;
67 use base::SafeDescriptor;
68 use base::SharedMemory;
69 use cros_async::TaskHandle;
70 use hypervisor::MemCacheType;
71 use serde::Deserialize;
72 use serde::Serialize;
73 use sync::Mutex;
74 use thiserror::Error as ThisError;
75 use vm_control::VmMemorySource;
76 use vm_memory::GuestAddress;
77 use vm_memory::GuestMemory;
78 use vm_memory::MemoryRegion;
79 use vmm_vhost::message::VhostSharedMemoryRegion;
80 use vmm_vhost::message::VhostUserConfigFlags;
81 use vmm_vhost::message::VhostUserExternalMapMsg;
82 use vmm_vhost::message::VhostUserGpuMapMsg;
83 use vmm_vhost::message::VhostUserInflight;
84 use vmm_vhost::message::VhostUserMemoryRegion;
85 use vmm_vhost::message::VhostUserProtocolFeatures;
86 use vmm_vhost::message::VhostUserShmemMapMsg;
87 use vmm_vhost::message::VhostUserShmemMapMsgFlags;
88 use vmm_vhost::message::VhostUserShmemUnmapMsg;
89 use vmm_vhost::message::VhostUserSingleMemoryRegion;
90 use vmm_vhost::message::VhostUserVringAddrFlags;
91 use vmm_vhost::message::VhostUserVringState;
92 use vmm_vhost::BackendReq;
93 use vmm_vhost::Connection;
94 use vmm_vhost::Error as VhostError;
95 use vmm_vhost::Frontend;
96 use vmm_vhost::FrontendClient;
97 use vmm_vhost::Result as VhostResult;
98 use vmm_vhost::VHOST_USER_F_PROTOCOL_FEATURES;
99 
100 use crate::virtio::Interrupt;
101 use crate::virtio::Queue;
102 use crate::virtio::QueueConfig;
103 use crate::virtio::SharedMemoryMapper;
104 use crate::virtio::SharedMemoryRegion;
105 
/// Keeps a mapping from the VMM's virtual addresses to guest physical addresses.
///
/// Used to translate addresses found in messages from the VMM (the vhost-user frontend) into
/// guest addresses; see `vmm_va_to_gpa`.
#[derive(Clone, Debug, Default, PartialEq, Eq)]
pub struct MappingInfo {
    /// Start of the region in the VMM's virtual address space.
    pub vmm_addr: u64,
    /// Guest physical address that `vmm_addr` corresponds to.
    pub guest_phys: u64,
    /// Length of the region.
    pub size: u64,
}
114 
vmm_va_to_gpa(maps: &[MappingInfo], vmm_va: u64) -> VhostResult<GuestAddress>115 pub fn vmm_va_to_gpa(maps: &[MappingInfo], vmm_va: u64) -> VhostResult<GuestAddress> {
116     for map in maps {
117         if vmm_va >= map.vmm_addr && vmm_va < map.vmm_addr + map.size {
118             return Ok(GuestAddress(vmm_va - map.vmm_addr + map.guest_phys));
119         }
120     }
121     Err(VhostError::InvalidMessage)
122 }
123 
124 /// Trait for vhost-user devices. Analogous to the `VirtioDevice` trait.
125 ///
126 /// In contrast with [[vmm_vhost::Backend]], which closely matches the vhost-user spec, this trait
127 /// is designed to follow crosvm conventions for implementing devices.
128 pub trait VhostUserDevice {
129     /// The maximum number of queues that this backend can manage.
max_queue_num(&self) -> usize130     fn max_queue_num(&self) -> usize;
131 
132     /// The set of feature bits that this backend supports.
features(&self) -> u64133     fn features(&self) -> u64;
134 
135     /// Acknowledges that this set of features should be enabled.
ack_features(&mut self, value: u64) -> anyhow::Result<()>136     fn ack_features(&mut self, value: u64) -> anyhow::Result<()>;
137 
138     /// Returns the set of enabled features.
acked_features(&self) -> u64139     fn acked_features(&self) -> u64;
140 
141     /// The set of protocol feature bits that this backend supports.
protocol_features(&self) -> VhostUserProtocolFeatures142     fn protocol_features(&self) -> VhostUserProtocolFeatures;
143 
144     /// Acknowledges that this set of protocol features should be enabled.
ack_protocol_features(&mut self, _value: u64) -> anyhow::Result<()>145     fn ack_protocol_features(&mut self, _value: u64) -> anyhow::Result<()>;
146 
147     /// Returns the set of enabled protocol features.
acked_protocol_features(&self) -> u64148     fn acked_protocol_features(&self) -> u64;
149 
150     /// Reads this device configuration space at `offset`.
read_config(&self, offset: u64, dst: &mut [u8])151     fn read_config(&self, offset: u64, dst: &mut [u8]);
152 
153     /// writes `data` to this device's configuration space at `offset`.
write_config(&self, _offset: u64, _data: &[u8])154     fn write_config(&self, _offset: u64, _data: &[u8]) {}
155 
156     /// Indicates that the backend should start processing requests for virtio queue number `idx`.
157     /// This method must not block the current thread so device backends should either spawn an
158     /// async task or another thread to handle messages from the Queue.
start_queue( &mut self, idx: usize, queue: Queue, mem: GuestMemory, doorbell: Interrupt, ) -> anyhow::Result<()>159     fn start_queue(
160         &mut self,
161         idx: usize,
162         queue: Queue,
163         mem: GuestMemory,
164         doorbell: Interrupt,
165     ) -> anyhow::Result<()>;
166 
167     /// Indicates that the backend should stop processing requests for virtio queue number `idx`.
168     /// This method should return the queue passed to `start_queue` for the corresponding `idx`.
169     /// This method will only be called for queues that were previously started by `start_queue`.
stop_queue(&mut self, idx: usize) -> anyhow::Result<Queue>170     fn stop_queue(&mut self, idx: usize) -> anyhow::Result<Queue>;
171 
172     /// Resets the vhost-user backend.
reset(&mut self)173     fn reset(&mut self);
174 
175     /// Returns the device's shared memory region if present.
get_shared_memory_region(&self) -> Option<SharedMemoryRegion>176     fn get_shared_memory_region(&self) -> Option<SharedMemoryRegion> {
177         None
178     }
179 
180     /// Accepts `VhostBackendReqConnection` to conduct Vhost backend to frontend message
181     /// handling.
182     ///
183     /// The backend is given an `Arc` instead of full ownership so that the framework can also use
184     /// the connection.
185     ///
186     /// This method will be called when `VhostUserProtocolFeatures::BACKEND_REQ` is
187     /// negotiated.
set_backend_req_connection(&mut self, _conn: Arc<VhostBackendReqConnection>)188     fn set_backend_req_connection(&mut self, _conn: Arc<VhostBackendReqConnection>) {
189         error!("set_backend_req_connection is not implemented");
190     }
191 
192     /// Used to stop non queue workers that `VhostUserDevice::stop_queue` can't stop. May or may
193     /// not also stop all queue workers.
stop_non_queue_workers(&mut self) -> anyhow::Result<()>194     fn stop_non_queue_workers(&mut self) -> anyhow::Result<()> {
195         error!("sleep not implemented for vhost user device");
196         // TODO(rizhang): Return error once basic devices support this.
197         Ok(())
198     }
199 
200     /// Snapshot device and return serialized bytes.
snapshot(&self) -> anyhow::Result<Vec<u8>>201     fn snapshot(&self) -> anyhow::Result<Vec<u8>> {
202         error!("snapshot not implemented for vhost user device");
203         // TODO(rizhang): Return error once basic devices support this.
204         Ok(Vec::new())
205     }
206 
restore(&mut self, _data: Vec<u8>) -> anyhow::Result<()>207     fn restore(&mut self, _data: Vec<u8>) -> anyhow::Result<()> {
208         error!("restore not implemented for vhost user device");
209         // TODO(rizhang): Return error once basic devices support this.
210         Ok(())
211     }
212 }
213 
214 /// A virtio ring entry.
215 struct Vring {
216     // The queue config. This doesn't get mutated by the queue workers.
217     queue: QueueConfig,
218     doorbell: Option<Interrupt>,
219     enabled: bool,
220     // Active queue that is only `Some` when the device is sleeping.
221     paused_queue: Option<Queue>,
222 }
223 
224 #[derive(Serialize, Deserialize)]
225 struct VringSnapshot {
226     // Snapshot of queue config.
227     queue: serde_json::Value,
228     // Snapshot of the activated queue state.
229     paused_queue: Option<serde_json::Value>,
230     enabled: bool,
231 }
232 
233 impl Vring {
new(max_size: u16, features: u64) -> Self234     fn new(max_size: u16, features: u64) -> Self {
235         Self {
236             queue: QueueConfig::new(max_size, features),
237             doorbell: None,
238             enabled: false,
239             paused_queue: None,
240         }
241     }
242 
reset(&mut self)243     fn reset(&mut self) {
244         self.queue.reset();
245         self.doorbell = None;
246         self.enabled = false;
247         self.paused_queue = None;
248     }
249 
snapshot(&self) -> anyhow::Result<VringSnapshot>250     fn snapshot(&self) -> anyhow::Result<VringSnapshot> {
251         Ok(VringSnapshot {
252             queue: self.queue.snapshot()?,
253             enabled: self.enabled,
254             paused_queue: self
255                 .paused_queue
256                 .as_ref()
257                 .map(Queue::snapshot)
258                 .transpose()?,
259         })
260     }
261 
restore( &mut self, vring_snapshot: VringSnapshot, mem: &GuestMemory, event: Option<Event>, ) -> anyhow::Result<()>262     fn restore(
263         &mut self,
264         vring_snapshot: VringSnapshot,
265         mem: &GuestMemory,
266         event: Option<Event>,
267     ) -> anyhow::Result<()> {
268         self.queue.restore(vring_snapshot.queue)?;
269         self.enabled = vring_snapshot.enabled;
270         self.paused_queue = vring_snapshot
271             .paused_queue
272             .map(|value| {
273                 Queue::restore(
274                     &self.queue,
275                     value,
276                     mem,
277                     event.context("missing queue event")?,
278                 )
279             })
280             .transpose()?;
281         Ok(())
282     }
283 }
284 
285 /// Ops for running vhost-user over a stream (i.e. regular protocol).
286 pub(super) struct VhostUserRegularOps;
287 
288 impl VhostUserRegularOps {
set_mem_table( contexts: &[VhostUserMemoryRegion], files: Vec<File>, ) -> VhostResult<(GuestMemory, Vec<MappingInfo>)>289     pub fn set_mem_table(
290         contexts: &[VhostUserMemoryRegion],
291         files: Vec<File>,
292     ) -> VhostResult<(GuestMemory, Vec<MappingInfo>)> {
293         if files.len() != contexts.len() {
294             return Err(VhostError::InvalidParam);
295         }
296 
297         let mut regions = Vec::with_capacity(files.len());
298         for (region, file) in contexts.iter().zip(files.into_iter()) {
299             let region = MemoryRegion::new_from_shm(
300                 region.memory_size,
301                 GuestAddress(region.guest_phys_addr),
302                 region.mmap_offset,
303                 Arc::new(
304                     SharedMemory::from_safe_descriptor(
305                         SafeDescriptor::from(file),
306                         region.memory_size,
307                     )
308                     .unwrap(),
309                 ),
310             )
311             .map_err(|e| {
312                 error!("failed to create a memory region: {}", e);
313                 VhostError::InvalidOperation
314             })?;
315             regions.push(region);
316         }
317         let guest_mem = GuestMemory::from_regions(regions).map_err(|e| {
318             error!("failed to create guest memory: {}", e);
319             VhostError::InvalidOperation
320         })?;
321 
322         let vmm_maps = contexts
323             .iter()
324             .map(|region| MappingInfo {
325                 vmm_addr: region.user_addr,
326                 guest_phys: region.guest_phys_addr,
327                 size: region.memory_size,
328             })
329             .collect();
330         Ok((guest_mem, vmm_maps))
331     }
332 
set_vring_kick(_index: u8, file: Option<File>) -> VhostResult<Event>333     pub fn set_vring_kick(_index: u8, file: Option<File>) -> VhostResult<Event> {
334         let file = file.ok_or(VhostError::InvalidParam)?;
335         // Remove O_NONBLOCK from kick_fd. Otherwise, uring_executor will fails when we read
336         // values via `next_val()` later.
337         // This is only required (and can only be done) on Unix platforms.
338         #[cfg(any(target_os = "android", target_os = "linux"))]
339         if let Err(e) = clear_fd_flags(file.as_raw_fd(), libc::O_NONBLOCK) {
340             error!("failed to remove O_NONBLOCK for kick fd: {}", e);
341             return Err(VhostError::InvalidParam);
342         }
343         Ok(Event::from(SafeDescriptor::from(file)))
344     }
345 
set_vring_call( _index: u8, file: Option<File>, signal_config_change_fn: Box<dyn Fn() + Send + Sync>, ) -> VhostResult<Interrupt>346     pub fn set_vring_call(
347         _index: u8,
348         file: Option<File>,
349         signal_config_change_fn: Box<dyn Fn() + Send + Sync>,
350     ) -> VhostResult<Interrupt> {
351         let file = file.ok_or(VhostError::InvalidParam)?;
352         Ok(Interrupt::new_vhost_user(
353             Event::from(SafeDescriptor::from(file)),
354             signal_config_change_fn,
355         ))
356     }
357 }
358 
359 /// An adapter that implements `vmm_vhost::Backend` for any type implementing `VhostUserDevice`.
360 pub struct DeviceRequestHandler<T: VhostUserDevice> {
361     vrings: Vec<Vring>,
362     owned: bool,
363     vmm_maps: Option<Vec<MappingInfo>>,
364     mem: Option<GuestMemory>,
365     backend: T,
366     backend_req_connection: Arc<Mutex<VhostBackendReqConnectionState>>,
367 }
368 
369 #[derive(Serialize, Deserialize)]
370 pub struct DeviceRequestHandlerSnapshot {
371     vrings: Vec<VringSnapshot>,
372     backend: Vec<u8>,
373 }
374 
375 impl<T: VhostUserDevice> DeviceRequestHandler<T> {
376     /// Creates a vhost-user handler instance for `backend`.
new(backend: T) -> Self377     pub(crate) fn new(backend: T) -> Self {
378         let mut vrings = Vec::with_capacity(backend.max_queue_num());
379         for _ in 0..backend.max_queue_num() {
380             vrings.push(Vring::new(Queue::MAX_SIZE, backend.features()));
381         }
382 
383         DeviceRequestHandler {
384             vrings,
385             owned: false,
386             vmm_maps: None,
387             mem: None,
388             backend,
389             backend_req_connection: Arc::new(Mutex::new(
390                 VhostBackendReqConnectionState::NoConnection,
391             )),
392         }
393     }
394 }
395 
396 impl<T: VhostUserDevice> AsRef<T> for DeviceRequestHandler<T> {
as_ref(&self) -> &T397     fn as_ref(&self) -> &T {
398         &self.backend
399     }
400 }
401 
402 impl<T: VhostUserDevice> AsMut<T> for DeviceRequestHandler<T> {
as_mut(&mut self) -> &mut T403     fn as_mut(&mut self) -> &mut T {
404         &mut self.backend
405     }
406 }
407 
408 impl<T: VhostUserDevice> vmm_vhost::Backend for DeviceRequestHandler<T> {
set_owner(&mut self) -> VhostResult<()>409     fn set_owner(&mut self) -> VhostResult<()> {
410         if self.owned {
411             return Err(VhostError::InvalidOperation);
412         }
413         self.owned = true;
414         Ok(())
415     }
416 
reset_owner(&mut self) -> VhostResult<()>417     fn reset_owner(&mut self) -> VhostResult<()> {
418         self.owned = false;
419         self.backend.reset();
420         Ok(())
421     }
422 
get_features(&mut self) -> VhostResult<u64>423     fn get_features(&mut self) -> VhostResult<u64> {
424         let features = self.backend.features();
425         Ok(features)
426     }
427 
set_features(&mut self, features: u64) -> VhostResult<()>428     fn set_features(&mut self, features: u64) -> VhostResult<()> {
429         if !self.owned {
430             return Err(VhostError::InvalidOperation);
431         }
432 
433         if (features & !(self.backend.features())) != 0 {
434             return Err(VhostError::InvalidParam);
435         }
436 
437         if let Err(e) = self.backend.ack_features(features) {
438             error!("failed to acknowledge features 0x{:x}: {}", features, e);
439             return Err(VhostError::InvalidOperation);
440         }
441 
442         // If VHOST_USER_F_PROTOCOL_FEATURES has not been negotiated, the ring is initialized in an
443         // enabled state.
444         // If VHOST_USER_F_PROTOCOL_FEATURES has been negotiated, the ring is initialized in a
445         // disabled state.
446         // Client must not pass data to/from the backend until ring is enabled by
447         // VHOST_USER_SET_VRING_ENABLE with parameter 1, or after it has been disabled by
448         // VHOST_USER_SET_VRING_ENABLE with parameter 0.
449         let acked_features = self.backend.acked_features();
450         let vring_enabled = acked_features & 1 << VHOST_USER_F_PROTOCOL_FEATURES != 0;
451         for v in &mut self.vrings {
452             v.enabled = vring_enabled;
453         }
454 
455         Ok(())
456     }
457 
get_protocol_features(&mut self) -> VhostResult<VhostUserProtocolFeatures>458     fn get_protocol_features(&mut self) -> VhostResult<VhostUserProtocolFeatures> {
459         Ok(self.backend.protocol_features())
460     }
461 
set_protocol_features(&mut self, features: u64) -> VhostResult<()>462     fn set_protocol_features(&mut self, features: u64) -> VhostResult<()> {
463         if let Err(e) = self.backend.ack_protocol_features(features) {
464             error!("failed to set protocol features 0x{:x}: {}", features, e);
465             return Err(VhostError::InvalidOperation);
466         }
467         Ok(())
468     }
469 
set_mem_table( &mut self, contexts: &[VhostUserMemoryRegion], files: Vec<File>, ) -> VhostResult<()>470     fn set_mem_table(
471         &mut self,
472         contexts: &[VhostUserMemoryRegion],
473         files: Vec<File>,
474     ) -> VhostResult<()> {
475         let (guest_mem, vmm_maps) = VhostUserRegularOps::set_mem_table(contexts, files)?;
476         self.mem = Some(guest_mem);
477         self.vmm_maps = Some(vmm_maps);
478         Ok(())
479     }
480 
get_queue_num(&mut self) -> VhostResult<u64>481     fn get_queue_num(&mut self) -> VhostResult<u64> {
482         Ok(self.vrings.len() as u64)
483     }
484 
set_vring_num(&mut self, index: u32, num: u32) -> VhostResult<()>485     fn set_vring_num(&mut self, index: u32, num: u32) -> VhostResult<()> {
486         if index as usize >= self.vrings.len() || num == 0 || num > Queue::MAX_SIZE.into() {
487             return Err(VhostError::InvalidParam);
488         }
489         self.vrings[index as usize].queue.set_size(num as u16);
490 
491         Ok(())
492     }
493 
set_vring_addr( &mut self, index: u32, _flags: VhostUserVringAddrFlags, descriptor: u64, used: u64, available: u64, _log: u64, ) -> VhostResult<()>494     fn set_vring_addr(
495         &mut self,
496         index: u32,
497         _flags: VhostUserVringAddrFlags,
498         descriptor: u64,
499         used: u64,
500         available: u64,
501         _log: u64,
502     ) -> VhostResult<()> {
503         if index as usize >= self.vrings.len() {
504             return Err(VhostError::InvalidParam);
505         }
506 
507         let vmm_maps = self.vmm_maps.as_ref().ok_or(VhostError::InvalidParam)?;
508         let vring = &mut self.vrings[index as usize];
509         vring
510             .queue
511             .set_desc_table(vmm_va_to_gpa(vmm_maps, descriptor)?);
512         vring
513             .queue
514             .set_avail_ring(vmm_va_to_gpa(vmm_maps, available)?);
515         vring.queue.set_used_ring(vmm_va_to_gpa(vmm_maps, used)?);
516 
517         Ok(())
518     }
519 
set_vring_base(&mut self, index: u32, base: u32) -> VhostResult<()>520     fn set_vring_base(&mut self, index: u32, base: u32) -> VhostResult<()> {
521         if index as usize >= self.vrings.len() || base >= Queue::MAX_SIZE.into() {
522             return Err(VhostError::InvalidParam);
523         }
524 
525         let vring = &mut self.vrings[index as usize];
526         vring.queue.set_next_avail(Wrapping(base as u16));
527         vring.queue.set_next_used(Wrapping(base as u16));
528 
529         Ok(())
530     }
531 
get_vring_base(&mut self, index: u32) -> VhostResult<VhostUserVringState>532     fn get_vring_base(&mut self, index: u32) -> VhostResult<VhostUserVringState> {
533         let vring = self
534             .vrings
535             .get_mut(index as usize)
536             .ok_or(VhostError::InvalidParam)?;
537 
538         // Quotation from vhost-user spec:
539         // "The back-end must [...] stop ring upon receiving VHOST_USER_GET_VRING_BASE."
540         // We only call `queue.set_ready()` when starting the queue, so if the queue is ready, that
541         // means it is started and should be stopped.
542         if vring.queue.ready() {
543             if let Err(e) = self.backend.stop_queue(index as usize) {
544                 error!("Failed to stop queue in get_vring_base: {:#}", e);
545             }
546 
547             vring.reset();
548         }
549 
550         Ok(VhostUserVringState::new(
551             index,
552             vring.queue.next_avail().0 as u32,
553         ))
554     }
555 
set_vring_kick(&mut self, index: u8, file: Option<File>) -> VhostResult<()>556     fn set_vring_kick(&mut self, index: u8, file: Option<File>) -> VhostResult<()> {
557         if index as usize >= self.vrings.len() {
558             return Err(VhostError::InvalidParam);
559         }
560 
561         let vring = &mut self.vrings[index as usize];
562         if vring.queue.ready() {
563             error!("kick fd cannot replaced after queue is started");
564             return Err(VhostError::InvalidOperation);
565         }
566 
567         let kick_evt = VhostUserRegularOps::set_vring_kick(index, file)?;
568 
569         // Enable any virtqueue features that were negotiated (like VIRTIO_RING_F_EVENT_IDX).
570         vring.queue.ack_features(self.backend.acked_features());
571         vring.queue.set_ready(true);
572 
573         let mem = self
574             .mem
575             .as_ref()
576             .cloned()
577             .ok_or(VhostError::InvalidOperation)?;
578 
579         let queue = match vring.queue.activate(&mem, kick_evt) {
580             Ok(queue) => queue,
581             Err(e) => {
582                 error!("failed to activate vring: {:#}", e);
583                 return Err(VhostError::BackendInternalError);
584             }
585         };
586 
587         let doorbell = vring.doorbell.clone().ok_or(VhostError::InvalidOperation)?;
588 
589         if let Err(e) = self
590             .backend
591             .start_queue(index as usize, queue, mem, doorbell)
592         {
593             error!("Failed to start queue {}: {}", index, e);
594             return Err(VhostError::BackendInternalError);
595         }
596 
597         Ok(())
598     }
599 
set_vring_call(&mut self, index: u8, file: Option<File>) -> VhostResult<()>600     fn set_vring_call(&mut self, index: u8, file: Option<File>) -> VhostResult<()> {
601         if index as usize >= self.vrings.len() {
602             return Err(VhostError::InvalidParam);
603         }
604 
605         let backend_req_conn = self.backend_req_connection.clone();
606         let signal_config_change_fn = Box::new(move || match &*backend_req_conn.lock() {
607             VhostBackendReqConnectionState::Connected(frontend) => {
608                 if let Err(e) = frontend.send_config_changed() {
609                     error!("Failed to notify config change: {:#}", e);
610                 }
611             }
612             VhostBackendReqConnectionState::NoConnection => {
613                 error!("No Backend request connection found");
614             }
615         });
616 
617         let doorbell = VhostUserRegularOps::set_vring_call(index, file, signal_config_change_fn)?;
618         self.vrings[index as usize].doorbell = Some(doorbell);
619         Ok(())
620     }
621 
set_vring_err(&mut self, _index: u8, _fd: Option<File>) -> VhostResult<()>622     fn set_vring_err(&mut self, _index: u8, _fd: Option<File>) -> VhostResult<()> {
623         // TODO
624         Ok(())
625     }
626 
set_vring_enable(&mut self, index: u32, enable: bool) -> VhostResult<()>627     fn set_vring_enable(&mut self, index: u32, enable: bool) -> VhostResult<()> {
628         if index as usize >= self.vrings.len() {
629             return Err(VhostError::InvalidParam);
630         }
631 
632         // This request should be handled only when VHOST_USER_F_PROTOCOL_FEATURES
633         // has been negotiated.
634         if self.backend.acked_features() & 1 << VHOST_USER_F_PROTOCOL_FEATURES == 0 {
635             return Err(VhostError::InvalidOperation);
636         }
637 
638         // Backend must not pass data to/from the ring until ring is enabled by
639         // VHOST_USER_SET_VRING_ENABLE with parameter 1, or after it has been disabled by
640         // VHOST_USER_SET_VRING_ENABLE with parameter 0.
641         self.vrings[index as usize].enabled = enable;
642 
643         Ok(())
644     }
645 
get_config( &mut self, offset: u32, size: u32, _flags: VhostUserConfigFlags, ) -> VhostResult<Vec<u8>>646     fn get_config(
647         &mut self,
648         offset: u32,
649         size: u32,
650         _flags: VhostUserConfigFlags,
651     ) -> VhostResult<Vec<u8>> {
652         let mut data = vec![0; size as usize];
653         self.backend.read_config(u64::from(offset), &mut data);
654         Ok(data)
655     }
656 
set_config( &mut self, offset: u32, buf: &[u8], _flags: VhostUserConfigFlags, ) -> VhostResult<()>657     fn set_config(
658         &mut self,
659         offset: u32,
660         buf: &[u8],
661         _flags: VhostUserConfigFlags,
662     ) -> VhostResult<()> {
663         self.backend.write_config(u64::from(offset), buf);
664         Ok(())
665     }
666 
set_backend_req_fd(&mut self, ep: Connection<BackendReq>)667     fn set_backend_req_fd(&mut self, ep: Connection<BackendReq>) {
668         let conn = Arc::new(VhostBackendReqConnection::new(
669             FrontendClient::new(ep),
670             self.backend.get_shared_memory_region().map(|r| r.id),
671         ));
672 
673         {
674             let mut backend_req_conn = self.backend_req_connection.lock();
675             if let VhostBackendReqConnectionState::Connected(_) = &*backend_req_conn {
676                 warn!("Backend Request Connection already established. Overwriting");
677             }
678             *backend_req_conn = VhostBackendReqConnectionState::Connected(conn.clone());
679         }
680 
681         self.backend.set_backend_req_connection(conn);
682     }
683 
get_inflight_fd( &mut self, _inflight: &VhostUserInflight, ) -> VhostResult<(VhostUserInflight, File)>684     fn get_inflight_fd(
685         &mut self,
686         _inflight: &VhostUserInflight,
687     ) -> VhostResult<(VhostUserInflight, File)> {
688         unimplemented!("get_inflight_fd");
689     }
690 
set_inflight_fd(&mut self, _inflight: &VhostUserInflight, _file: File) -> VhostResult<()>691     fn set_inflight_fd(&mut self, _inflight: &VhostUserInflight, _file: File) -> VhostResult<()> {
692         unimplemented!("set_inflight_fd");
693     }
694 
get_max_mem_slots(&mut self) -> VhostResult<u64>695     fn get_max_mem_slots(&mut self) -> VhostResult<u64> {
696         //TODO
697         Ok(0)
698     }
699 
add_mem_region( &mut self, _region: &VhostUserSingleMemoryRegion, _fd: File, ) -> VhostResult<()>700     fn add_mem_region(
701         &mut self,
702         _region: &VhostUserSingleMemoryRegion,
703         _fd: File,
704     ) -> VhostResult<()> {
705         //TODO
706         Ok(())
707     }
708 
remove_mem_region(&mut self, _region: &VhostUserSingleMemoryRegion) -> VhostResult<()>709     fn remove_mem_region(&mut self, _region: &VhostUserSingleMemoryRegion) -> VhostResult<()> {
710         //TODO
711         Ok(())
712     }
713 
get_shared_memory_regions(&mut self) -> VhostResult<Vec<VhostSharedMemoryRegion>>714     fn get_shared_memory_regions(&mut self) -> VhostResult<Vec<VhostSharedMemoryRegion>> {
715         Ok(if let Some(r) = self.backend.get_shared_memory_region() {
716             vec![VhostSharedMemoryRegion::new(r.id, r.length)]
717         } else {
718             Vec::new()
719         })
720     }
721 
sleep(&mut self) -> VhostResult<()>722     fn sleep(&mut self) -> VhostResult<()> {
723         for (index, vring) in self
724             .vrings
725             .iter_mut()
726             .enumerate()
727             .filter(|(_index, vring)| vring.queue.ready())
728         {
729             match self.backend.stop_queue(index) {
730                 Ok(queue) => vring.paused_queue = Some(queue),
731                 Err(e) => {
732                     error!("failed to stop queue index {}: {:#}", index, e);
733                     return Err(VhostError::StopQueueError(e));
734                 }
735             }
736         }
737         self.backend
738             .stop_non_queue_workers()
739             .map_err(VhostError::SleepError)
740     }
741 
wake(&mut self) -> VhostResult<()>742     fn wake(&mut self) -> VhostResult<()> {
743         for (index, vring) in self.vrings.iter_mut().enumerate() {
744             if let Some(queue) = vring.paused_queue.take() {
745                 let mem = self.mem.clone().ok_or(VhostError::BackendInternalError)?;
746                 let doorbell = vring.doorbell.clone().expect("Failed to clone doorbell");
747 
748                 if let Err(e) = self.backend.start_queue(index, queue, mem, doorbell) {
749                     error!("Failed to start queue {}: {}", index, e);
750                     return Err(VhostError::BackendInternalError);
751                 }
752             }
753         }
754         Ok(())
755     }
756 
snapshot(&mut self) -> VhostResult<Vec<u8>>757     fn snapshot(&mut self) -> VhostResult<Vec<u8>> {
758         match serde_json::to_vec(&DeviceRequestHandlerSnapshot {
759             vrings: self
760                 .vrings
761                 .iter()
762                 .map(|vring| vring.snapshot())
763                 .collect::<anyhow::Result<Vec<VringSnapshot>>>()
764                 .map_err(VhostError::SnapshotError)?,
765             backend: self.backend.snapshot().map_err(VhostError::SnapshotError)?,
766         }) {
767             Ok(serialized_json) => Ok(serialized_json),
768             Err(e) => {
769                 error!("Failed to serialize DeviceRequestHandlerSnapshot: {}", e);
770                 Err(VhostError::SerializationFailed)
771             }
772         }
773     }
774 
restore(&mut self, data_bytes: &[u8], queue_evts: Vec<File>) -> VhostResult<()>775     fn restore(&mut self, data_bytes: &[u8], queue_evts: Vec<File>) -> VhostResult<()> {
776         let device_request_handler_snapshot: DeviceRequestHandlerSnapshot =
777             serde_json::from_slice(data_bytes).map_err(|e| {
778                 error!("Failed to deserialize DeviceRequestHandlerSnapshot: {}", e);
779                 VhostError::DeserializationFailed
780             })?;
781 
782         let mem = self.mem.as_ref().ok_or(VhostError::InvalidOperation)?;
783 
784         let snapshotted_vrings = device_request_handler_snapshot.vrings;
785         assert_eq!(snapshotted_vrings.len(), self.vrings.len());
786 
787         let mut queue_evts_iter = if queue_evts.is_empty() {
788             None
789         } else {
790             Some(queue_evts.into_iter())
791         };
792 
793         for (index, (vring, snapshotted_vring)) in self
794             .vrings
795             .iter_mut()
796             .zip(snapshotted_vrings.into_iter())
797             .enumerate()
798         {
799             let queue_evt = if let Some(queue_evts_iter) = &mut queue_evts_iter {
800                 // TODO(b/288596005): It is assumed that the index of `queue_evts` should map to the
801                 // index of `self.vrings`. However, this assumption may break in the future, so a
802                 // Map of indexes to queue_evt should be used to support sparse activated queues.
803                 let queue_evt_file = queue_evts_iter
804                     .next()
805                     .ok_or(VhostError::VringIndexNotFound(index))?;
806                 Some(VhostUserRegularOps::set_vring_kick(
807                     index as u8,
808                     Some(queue_evt_file),
809                 )?)
810             } else {
811                 None
812             };
813 
814             vring
815                 .restore(snapshotted_vring, mem, queue_evt)
816                 .map_err(VhostError::RestoreError)?;
817         }
818 
819         self.backend
820             .restore(device_request_handler_snapshot.backend)
821             .map_err(VhostError::RestoreError)?;
822 
823         Ok(())
824     }
825 }
826 
827 /// Indicates the state of backend request connection
828 pub enum VhostBackendReqConnectionState {
829     /// A backend request connection (`VhostBackendReqConnection`) is established
830     Connected(Arc<VhostBackendReqConnection>),
831     /// No backend request connection has been established yet
832     NoConnection,
833 }
834 
835 /// Keeps track of Vhost user backend request connection.
836 pub struct VhostBackendReqConnection {
837     conn: Arc<Mutex<FrontendClient>>,
838     shmem_info: Mutex<Option<ShmemInfo>>,
839 }
840 
/// Bookkeeping for one shared memory region managed through the frontend.
#[derive(Clone)]
struct ShmemInfo {
    /// Id of the shared memory region that mapping requests target.
    shmid: u8,
    /// Currently mapped ranges, keyed by offset within the region and storing their size.
    mapped_regions: BTreeMap<u64 /* offset */, u64 /* size */>,
}
846 
847 impl VhostBackendReqConnection {
new(conn: FrontendClient, shmid: Option<u8>) -> Self848     pub fn new(conn: FrontendClient, shmid: Option<u8>) -> Self {
849         let shmem_info = Mutex::new(shmid.map(|shmid| ShmemInfo {
850             shmid,
851             mapped_regions: BTreeMap::new(),
852         }));
853         Self {
854             conn: Arc::new(Mutex::new(conn)),
855             shmem_info,
856         }
857     }
858 
859     /// Send `VHOST_USER_CONFIG_CHANGE_MSG` to the frontend
send_config_changed(&self) -> anyhow::Result<()>860     pub fn send_config_changed(&self) -> anyhow::Result<()> {
861         self.conn
862             .lock()
863             .handle_config_change()
864             .context("Could not send config change message")?;
865         Ok(())
866     }
867 
868     /// Create a SharedMemoryMapper trait object from the ShmemInfo.
take_shmem_mapper(&self) -> anyhow::Result<Box<dyn SharedMemoryMapper>>869     pub fn take_shmem_mapper(&self) -> anyhow::Result<Box<dyn SharedMemoryMapper>> {
870         let shmem_info = self
871             .shmem_info
872             .lock()
873             .take()
874             .context("could not take shared memory mapper information")?;
875 
876         Ok(Box::new(VhostShmemMapper {
877             conn: self.conn.clone(),
878             shmem_info,
879         }))
880     }
881 }
882 
883 struct VhostShmemMapper {
884     conn: Arc<Mutex<FrontendClient>>,
885     shmem_info: ShmemInfo,
886 }
887 
888 impl SharedMemoryMapper for VhostShmemMapper {
add_mapping( &mut self, source: VmMemorySource, offset: u64, prot: Protection, _cache: MemCacheType, ) -> anyhow::Result<()>889     fn add_mapping(
890         &mut self,
891         source: VmMemorySource,
892         offset: u64,
893         prot: Protection,
894         _cache: MemCacheType,
895     ) -> anyhow::Result<()> {
896         let size = match source {
897             VmMemorySource::Vulkan {
898                 descriptor,
899                 handle_type,
900                 memory_idx,
901                 device_uuid,
902                 driver_uuid,
903                 size,
904             } => {
905                 let msg = VhostUserGpuMapMsg::new(
906                     self.shmem_info.shmid,
907                     offset,
908                     size,
909                     memory_idx,
910                     handle_type,
911                     device_uuid,
912                     driver_uuid,
913                 );
914                 self.conn
915                     .lock()
916                     .gpu_map(&msg, &descriptor)
917                     .context("failed to map memory")?;
918                 size
919             }
920             VmMemorySource::ExternalMapping { ptr, size } => {
921                 let msg = VhostUserExternalMapMsg::new(self.shmem_info.shmid, offset, size, ptr);
922                 self.conn
923                     .lock()
924                     .external_map(&msg)
925                     .context("failed to map memory")?;
926                 size
927             }
928             source => {
929                 // The last two sources use the same VhostUserShmemMapMsg, continue matching here
930                 // on the aliased `source` above.
931                 let (descriptor, fd_offset, size) = match source {
932                     VmMemorySource::Descriptor {
933                         descriptor,
934                         offset,
935                         size,
936                     } => (descriptor, offset, size),
937                     VmMemorySource::SharedMemory(shmem) => {
938                         let size = shmem.size();
939                         let descriptor =
940                             // SAFETY:
941                             // Safe because we own shmem.
942                             unsafe {
943                                 SafeDescriptor::from_raw_descriptor(shmem.into_raw_descriptor())
944                             };
945                         (descriptor, 0, size)
946                     }
947                     _ => bail!("unsupported source"),
948                 };
949                 let flags = VhostUserShmemMapMsgFlags::from(prot);
950                 let msg = VhostUserShmemMapMsg::new(
951                     self.shmem_info.shmid,
952                     offset,
953                     fd_offset,
954                     size,
955                     flags,
956                 );
957                 self.conn
958                     .lock()
959                     .shmem_map(&msg, &descriptor)
960                     .context("failed to map memory")?;
961                 size
962             }
963         };
964 
965         self.shmem_info.mapped_regions.insert(offset, size);
966         Ok(())
967     }
968 
remove_mapping(&mut self, offset: u64) -> anyhow::Result<()>969     fn remove_mapping(&mut self, offset: u64) -> anyhow::Result<()> {
970         let size = self
971             .shmem_info
972             .mapped_regions
973             .remove(&offset)
974             .context("unknown offset")?;
975         let msg = VhostUserShmemUnmapMsg::new(self.shmem_info.shmid, offset, size);
976         self.conn
977             .lock()
978             .shmem_unmap(&msg)
979             .context("failed to map memory")
980             .map(|_| ())
981     }
982 }
983 
984 pub(crate) struct WorkerState<T, U> {
985     pub(crate) queue_task: TaskHandle<U>,
986     pub(crate) queue: T,
987 }
988 
989 /// Errors for device operations
990 #[derive(Debug, ThisError)]
991 pub enum Error {
992     #[error("worker not found when stopping queue")]
993     WorkerNotFound,
994 }
995 
#[cfg(test)]
mod tests {
    use std::sync::mpsc::channel;
    use std::sync::Barrier;

    use anyhow::anyhow;
    use anyhow::bail;
    use base::Event;
    use vmm_vhost::BackendServer;
    use vmm_vhost::FrontendReq;
    use zerocopy::AsBytes;
    use zerocopy::FromBytes;
    use zerocopy::FromZeroes;

    use super::sys::test_helpers;
    use super::*;
    use crate::virtio::vhost_user_frontend::VhostUserFrontend;
    use crate::virtio::DeviceType;
    use crate::virtio::VirtioDevice;

    /// Config space served by `FakeBackend`.
    ///
    /// `packed(4)` places `y` directly after `x`, so the struct has no padding bytes.
    #[derive(Clone, Copy, Debug, PartialEq, Eq, AsBytes, FromZeroes, FromBytes)]
    #[repr(C, packed(4))]
    struct FakeConfig {
        x: u32,
        y: u64,
    }

    const FAKE_CONFIG_DATA: FakeConfig = FakeConfig { x: 1, y: 2 };

    /// Minimal in-memory `VhostUserDevice` used to drive `DeviceRequestHandler` in tests.
    pub(super) struct FakeBackend {
        avail_features: u64,
        acked_features: u64,
        acked_protocol_features: VhostUserProtocolFeatures,
        /// Queues handed over by `start_queue`, indexed by queue number.
        active_queues: Vec<Option<Queue>>,
        /// When true, `BACKEND_REQ` is advertised in the protocol features.
        allow_backend_req: bool,
        /// Connection received through `set_backend_req_connection`, if any.
        backend_conn: Option<Arc<VhostBackendReqConnection>>,
    }

    impl FakeBackend {
        const MAX_QUEUE_NUM: usize = 16;

        /// Creates a backend that offers only `VHOST_USER_F_PROTOCOL_FEATURES` and has no
        /// active queues.
        pub(super) fn new() -> Self {
            let mut active_queues = Vec::new();
            active_queues.resize_with(Self::MAX_QUEUE_NUM, Default::default);
            Self {
                avail_features: 1 << VHOST_USER_F_PROTOCOL_FEATURES,
                acked_features: 0,
                acked_protocol_features: VhostUserProtocolFeatures::empty(),
                active_queues,
                allow_backend_req: false,
                backend_conn: None,
            }
        }
    }

    impl VhostUserDevice for FakeBackend {
        fn max_queue_num(&self) -> usize {
            Self::MAX_QUEUE_NUM
        }

        fn features(&self) -> u64 {
            self.avail_features
        }

        fn ack_features(&mut self, value: u64) -> anyhow::Result<()> {
            let unrequested_features = value & !self.avail_features;
            if unrequested_features != 0 {
                bail!("invalid features are given: 0x{:x}", unrequested_features);
            }
            self.acked_features |= value;
            Ok(())
        }

        fn acked_features(&self) -> u64 {
            self.acked_features
        }

        fn protocol_features(&self) -> VhostUserProtocolFeatures {
            let mut features = VhostUserProtocolFeatures::CONFIG;
            if self.allow_backend_req {
                features |= VhostUserProtocolFeatures::BACKEND_REQ;
            }
            features
        }

        fn ack_protocol_features(&mut self, features: u64) -> anyhow::Result<()> {
            let features = VhostUserProtocolFeatures::from_bits(features).ok_or(anyhow!(
                "invalid protocol features are given: 0x{:x}",
                features
            ))?;
            let supported = self.protocol_features();
            self.acked_protocol_features = features & supported;
            Ok(())
        }

        fn acked_protocol_features(&self) -> u64 {
            self.acked_protocol_features.bits()
        }

        fn read_config(&self, offset: u64, dst: &mut [u8]) {
            dst.copy_from_slice(&FAKE_CONFIG_DATA.as_bytes()[offset as usize..]);
        }

        fn reset(&mut self) {}

        fn start_queue(
            &mut self,
            idx: usize,
            queue: Queue,
            _mem: GuestMemory,
            _doorbell: Interrupt,
        ) -> anyhow::Result<()> {
            self.active_queues[idx] = Some(queue);
            Ok(())
        }

        fn stop_queue(&mut self, idx: usize) -> anyhow::Result<Queue> {
            Ok(self.active_queues[idx]
                .take()
                .ok_or(Error::WorkerNotFound)?)
        }

        fn set_backend_req_connection(&mut self, conn: Arc<VhostBackendReqConnection>) {
            self.backend_conn = Some(conn);
        }
    }

    #[test]
    fn test_vhost_user_activate() {
        test_vhost_user_activate_parameterized(false);
    }

    #[test]
    #[cfg(not(windows))] // Windows requires more complex connection setup.
    fn test_vhost_user_activate_with_backend_req() {
        test_vhost_user_activate_parameterized(true);
    }

    /// Drives a full session between a `VhostUserFrontend`, running on a separate thread as
    /// the VMM side, and a `DeviceRequestHandler<FakeBackend>`, running on this thread as the
    /// device side.
    ///
    /// The device side checks, one at a time and in order, the requests produced by: new,
    /// read_config, activate, reset, activate, sleep, and wake.
    fn test_vhost_user_activate_parameterized(allow_backend_req: bool) {
        const QUEUES_NUM: usize = 2;

        let (dev, vmm) = test_helpers::setup();

        let vmm_bar = Arc::new(Barrier::new(2));
        let dev_bar = vmm_bar.clone();

        let (ready_tx, ready_rx) = channel();
        let (shutdown_tx, shutdown_rx) = channel();

        std::thread::spawn(move || {
            // VMM side
            ready_rx.recv().unwrap(); // Ensure the device is ready.

            let connection = test_helpers::connect(vmm);

            let mut vmm_device =
                VhostUserFrontend::new(DeviceType::Console, 0, connection, None, None).unwrap();

            println!("read_config");
            let mut buf = vec![0; std::mem::size_of::<FakeConfig>()];
            vmm_device.read_config(0, &mut buf);
            // Check if the obtained config data is correct.
            let config = FakeConfig::read_from(buf.as_bytes()).unwrap();
            assert_eq!(config, FAKE_CONFIG_DATA);

            let activate = |vmm_device: &mut VhostUserFrontend| {
                let mem = GuestMemory::new(&[(GuestAddress(0x0), 0x10000)]).unwrap();
                let interrupt = Interrupt::new_for_test_with_msix();

                let mut queues = BTreeMap::new();
                for idx in 0..QUEUES_NUM {
                    let mut queue = QueueConfig::new(0x10, 0);
                    queue.set_ready(true);
                    let queue = queue
                        .activate(&mem, Event::new().unwrap())
                        .expect("QueueConfig::activate");
                    queues.insert(idx, queue);
                }

                println!("activate");
                vmm_device
                    .activate(mem.clone(), interrupt.clone(), queues)
                    .unwrap();
            };

            activate(&mut vmm_device);

            println!("reset");
            let reset_result = vmm_device.reset();
            assert!(
                reset_result.is_ok(),
                "reset failed: {:#}",
                reset_result.unwrap_err()
            );

            activate(&mut vmm_device);

            println!("virtio_sleep");
            vmm_device.virtio_sleep().unwrap();

            println!("virtio_wake");
            vmm_device.virtio_wake(None).unwrap();

            println!("wait for shutdown signal");
            shutdown_rx.recv().unwrap();

            // The VMM side is supposed to stop before the device side.
            println!("drop");
            drop(vmm_device);

            vmm_bar.wait();
        });

        // Device side
        let mut handler = DeviceRequestHandler::new(FakeBackend::new());
        handler.as_mut().allow_backend_req = allow_backend_req;

        // Notify listener is ready.
        ready_tx.send(()).unwrap();

        let mut req_handler = test_helpers::listen(dev, handler);

        // VhostUserFrontend::new()
        handle_request(&mut req_handler, FrontendReq::SET_OWNER).unwrap();
        handle_request(&mut req_handler, FrontendReq::GET_FEATURES).unwrap();
        handle_request(&mut req_handler, FrontendReq::SET_FEATURES).unwrap();
        handle_request(&mut req_handler, FrontendReq::GET_PROTOCOL_FEATURES).unwrap();
        handle_request(&mut req_handler, FrontendReq::SET_PROTOCOL_FEATURES).unwrap();
        if allow_backend_req {
            handle_request(&mut req_handler, FrontendReq::SET_BACKEND_REQ_FD).unwrap();
        }

        // VhostUserFrontend::read_config()
        handle_request(&mut req_handler, FrontendReq::GET_CONFIG).unwrap();

        // VhostUserFrontend::activate()
        handle_activate_requests(&mut req_handler, QUEUES_NUM);

        // VhostUserFrontend::reset()
        for _ in 0..QUEUES_NUM {
            handle_request(&mut req_handler, FrontendReq::SET_VRING_ENABLE).unwrap();
            handle_request(&mut req_handler, FrontendReq::GET_VRING_BASE).unwrap();
        }

        // VhostUserFrontend::activate()
        handle_activate_requests(&mut req_handler, QUEUES_NUM);

        if allow_backend_req {
            // Make sure the connection still works even after reset/reactivate.
            req_handler
                .as_ref()
                .as_ref()
                .backend_conn
                .as_ref()
                .expect("backend_conn missing")
                .send_config_changed()
                .expect("send_config_changed failed");
        }

        // VhostUserFrontend::virtio_sleep()
        handle_request(&mut req_handler, FrontendReq::SLEEP).unwrap();

        // VhostUserFrontend::virtio_wake()
        handle_request(&mut req_handler, FrontendReq::WAKE).unwrap();

        if allow_backend_req {
            // Make sure the connection still works even after sleep/wake.
            req_handler
                .as_ref()
                .as_ref()
                .backend_conn
                .as_ref()
                .expect("backend_conn missing")
                .send_config_changed()
                .expect("send_config_changed failed");
        }

        // Ask the client to shut down, then wait for it to finish.
        shutdown_tx.send(()).unwrap();
        dev_bar.wait();

        // Verify recv_header fails with `ClientExit` after the client has disconnected.
        match req_handler.recv_header() {
            Err(VhostError::ClientExit) => (),
            r => panic!("expected Err(ClientExit) but got {:?}", r),
        }
    }

    /// Handles, in order, the requests that `VhostUserFrontend::activate()` sends for a device
    /// with `queues_num` queues: the memory table, then the per-queue vring setup.
    fn handle_activate_requests<S: vmm_vhost::Backend>(
        handler: &mut BackendServer<S>,
        queues_num: usize,
    ) {
        handle_request(handler, FrontendReq::SET_MEM_TABLE).unwrap();
        for _ in 0..queues_num {
            handle_request(handler, FrontendReq::SET_VRING_NUM).unwrap();
            handle_request(handler, FrontendReq::SET_VRING_ADDR).unwrap();
            handle_request(handler, FrontendReq::SET_VRING_BASE).unwrap();
            handle_request(handler, FrontendReq::SET_VRING_CALL).unwrap();
            handle_request(handler, FrontendReq::SET_VRING_KICK).unwrap();
            handle_request(handler, FrontendReq::SET_VRING_ENABLE).unwrap();
        }
    }

    /// Receives one request, asserts it has the expected type, and processes it.
    fn handle_request<S: vmm_vhost::Backend>(
        handler: &mut BackendServer<S>,
        expected_message_type: FrontendReq,
    ) -> Result<(), VhostError> {
        let (hdr, files) = handler.recv_header()?;
        assert_eq!(hdr.get_code(), Ok(expected_message_type));
        handler.process_message(hdr, files)
    }
}
1312