• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
// Copyright 2021 The ChromiumOS Authors
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.

//! Library for implementing vhost-user device executables.
//!
//! This module provides:
//! * the `VhostUserDevice` trait, a collection of methods for handling vhost-user requests, and
//! * the `DeviceRequestHandler` struct, which makes a connection to a VMM and starts an event loop.
//!
//! They are expected to be used as follows:
//!
//! 1. Define a struct and implement `VhostUserDevice` for it.
//! 2. Create a `DeviceRequestHandler` with the backend struct.
//! 3. Drive the `DeviceRequestHandler::run` async fn with an executor.
//!
//! ```ignore
//! struct MyBackend {
//!   /* fields */
//! }
//!
//! impl VhostUserDevice for MyBackend {
//!   /* implement methods */
//! }
//!
//! fn main() -> Result<(), Box<dyn Error>> {
//!   let backend = MyBackend { /* initialize fields */ };
//!   let handler = DeviceRequestHandler::new(backend);
//!   let socket = std::path::Path::new("/path/to/socket");
//!   let ex = cros_async::Executor::new()?;
//!
//!   if let Err(e) = ex.run_until(handler.run(socket, &ex)) {
//!     eprintln!("error happened: {}", e);
//!   }
//!   Ok(())
//! }
//! ```
// Implementation note:
// This code lets us take advantage of the vmm_vhost low-level implementation of the vhost-user
// protocol. DeviceRequestHandler implements the Backend trait from vmm_vhost and includes some
// common code for setting up guest memory and managing partially configured vrings.
// DeviceRequestHandler::run watches the vhost-user socket and calls handle_request() whenever it
// becomes readable. handle_request() reads and parses the message, then calls one of the
// Backend trait methods. These dispatch back to the supplied VhostUserDevice implementation (which
// is what our devices implement).

47 pub(super) mod sys;
48 
49 use std::collections::BTreeMap;
50 use std::convert::From;
51 use std::fs::File;
52 use std::num::Wrapping;
53 #[cfg(any(target_os = "android", target_os = "linux"))]
54 use std::os::unix::io::AsRawFd;
55 use std::sync::Arc;
56 
57 use anyhow::bail;
58 use anyhow::Context;
59 #[cfg(any(target_os = "android", target_os = "linux"))]
60 use base::clear_fd_flags;
61 use base::error;
62 use base::trace;
63 use base::warn;
64 use base::Event;
65 use base::Protection;
66 use base::SafeDescriptor;
67 use base::SharedMemory;
68 use base::WorkerThread;
69 use cros_async::TaskHandle;
70 use hypervisor::MemCacheType;
71 use serde::Deserialize;
72 use serde::Serialize;
73 use snapshot::AnySnapshot;
74 use sync::Mutex;
75 use thiserror::Error as ThisError;
76 use vm_control::VmMemorySource;
77 use vm_memory::GuestAddress;
78 use vm_memory::GuestMemory;
79 use vm_memory::MemoryRegion;
80 use vmm_vhost::message::VhostSharedMemoryRegion;
81 use vmm_vhost::message::VhostUserConfigFlags;
82 use vmm_vhost::message::VhostUserExternalMapMsg;
83 use vmm_vhost::message::VhostUserGpuMapMsg;
84 use vmm_vhost::message::VhostUserInflight;
85 use vmm_vhost::message::VhostUserMemoryRegion;
86 use vmm_vhost::message::VhostUserMigrationPhase;
87 use vmm_vhost::message::VhostUserProtocolFeatures;
88 use vmm_vhost::message::VhostUserShmemMapMsg;
89 use vmm_vhost::message::VhostUserShmemMapMsgFlags;
90 use vmm_vhost::message::VhostUserShmemUnmapMsg;
91 use vmm_vhost::message::VhostUserSingleMemoryRegion;
92 use vmm_vhost::message::VhostUserTransferDirection;
93 use vmm_vhost::message::VhostUserVringAddrFlags;
94 use vmm_vhost::message::VhostUserVringState;
95 use vmm_vhost::BackendReq;
96 use vmm_vhost::Connection;
97 use vmm_vhost::Error as VhostError;
98 use vmm_vhost::Frontend;
99 use vmm_vhost::FrontendClient;
100 use vmm_vhost::Result as VhostResult;
101 use vmm_vhost::VHOST_USER_F_PROTOCOL_FEATURES;
102 
103 use crate::virtio::Interrupt;
104 use crate::virtio::Queue;
105 use crate::virtio::QueueConfig;
106 use crate::virtio::SharedMemoryMapper;
107 use crate::virtio::SharedMemoryRegion;
108 
/// One contiguous region of guest memory as seen from the frontend (VMM) process.
///
/// A list of these is used to translate frontend virtual addresses carried in vhost-user
/// messages into guest physical addresses.
#[derive(Default)]
pub struct MappingInfo {
    /// Start of the region in the frontend's virtual address space.
    pub vmm_addr: u64,
    /// Guest physical address that `vmm_addr` corresponds to.
    pub guest_phys: u64,
    /// Length of the region in bytes.
    pub size: u64,
}
117 
vmm_va_to_gpa(maps: &[MappingInfo], vmm_va: u64) -> VhostResult<GuestAddress>118 pub fn vmm_va_to_gpa(maps: &[MappingInfo], vmm_va: u64) -> VhostResult<GuestAddress> {
119     for map in maps {
120         if vmm_va >= map.vmm_addr && vmm_va < map.vmm_addr + map.size {
121             return Ok(GuestAddress(vmm_va - map.vmm_addr + map.guest_phys));
122         }
123     }
124     Err(VhostError::InvalidMessage)
125 }
126 
127 /// Trait for vhost-user devices. Analogous to the `VirtioDevice` trait.
128 ///
129 /// In contrast with [[vmm_vhost::Backend]], which closely matches the vhost-user spec, this trait
130 /// is designed to follow crosvm conventions for implementing devices.
131 pub trait VhostUserDevice {
132     /// The maximum number of queues that this backend can manage.
max_queue_num(&self) -> usize133     fn max_queue_num(&self) -> usize;
134 
135     /// The set of feature bits that this backend supports.
features(&self) -> u64136     fn features(&self) -> u64;
137 
138     /// Acknowledges that this set of features should be enabled.
139     ///
140     /// Implementations only need to handle device-specific feature bits; the `DeviceRequestHandler`
141     /// framework will manage generic vhost and vring features.
142     ///
143     /// `DeviceRequestHandler` checks for valid features before calling this function, so the
144     /// features in `value` will always be a subset of those advertised by `features()`.
ack_features(&mut self, _value: u64) -> anyhow::Result<()>145     fn ack_features(&mut self, _value: u64) -> anyhow::Result<()> {
146         Ok(())
147     }
148 
149     /// The set of protocol feature bits that this backend supports.
protocol_features(&self) -> VhostUserProtocolFeatures150     fn protocol_features(&self) -> VhostUserProtocolFeatures;
151 
152     /// Reads this device configuration space at `offset`.
read_config(&self, offset: u64, dst: &mut [u8])153     fn read_config(&self, offset: u64, dst: &mut [u8]);
154 
155     /// writes `data` to this device's configuration space at `offset`.
write_config(&self, _offset: u64, _data: &[u8])156     fn write_config(&self, _offset: u64, _data: &[u8]) {}
157 
158     /// Indicates that the backend should start processing requests for virtio queue number `idx`.
159     /// This method must not block the current thread so device backends should either spawn an
160     /// async task or another thread to handle messages from the Queue.
start_queue(&mut self, idx: usize, queue: Queue, mem: GuestMemory) -> anyhow::Result<()>161     fn start_queue(&mut self, idx: usize, queue: Queue, mem: GuestMemory) -> anyhow::Result<()>;
162 
163     /// Indicates that the backend should stop processing requests for virtio queue number `idx`.
164     /// This method should return the queue passed to `start_queue` for the corresponding `idx`.
165     /// This method will only be called for queues that were previously started by `start_queue`.
stop_queue(&mut self, idx: usize) -> anyhow::Result<Queue>166     fn stop_queue(&mut self, idx: usize) -> anyhow::Result<Queue>;
167 
168     /// Resets the vhost-user backend.
reset(&mut self)169     fn reset(&mut self);
170 
171     /// Returns the device's shared memory region if present.
get_shared_memory_region(&self) -> Option<SharedMemoryRegion>172     fn get_shared_memory_region(&self) -> Option<SharedMemoryRegion> {
173         None
174     }
175 
176     /// Accepts `VhostBackendReqConnection` to conduct Vhost backend to frontend message
177     /// handling.
178     ///
179     /// The backend is given an `Arc` instead of full ownership so that the framework can also use
180     /// the connection.
181     ///
182     /// This method will be called when `VhostUserProtocolFeatures::BACKEND_REQ` is
183     /// negotiated.
set_backend_req_connection(&mut self, _conn: Arc<VhostBackendReqConnection>)184     fn set_backend_req_connection(&mut self, _conn: Arc<VhostBackendReqConnection>) {}
185 
186     /// Enter the "suspended device state" described in the vhost-user spec. See the spec for
187     /// requirements.
188     ///
189     /// One reasonably foolproof way to satisfy the requirements is to stop all worker threads.
190     ///
191     /// Called after a `stop_queue` call if there are no running queues left. Also called soon
192     /// after device creation to ensure the device is acting suspended immediately on construction.
193     ///
194     /// The next `start_queue` call implicitly exits the "suspend device state".
195     ///
196     /// * Ok(())    => device successfully suspended
197     /// * Err(_)    => unrecoverable error
enter_suspended_state(&mut self) -> anyhow::Result<()>198     fn enter_suspended_state(&mut self) -> anyhow::Result<()>;
199 
200     /// Snapshot device and return serialized state.
snapshot(&mut self) -> anyhow::Result<AnySnapshot>201     fn snapshot(&mut self) -> anyhow::Result<AnySnapshot>;
202 
203     /// Restore device state from a snapshot.
restore(&mut self, data: AnySnapshot) -> anyhow::Result<()>204     fn restore(&mut self, data: AnySnapshot) -> anyhow::Result<()>;
205 }
206 
207 /// A virtio ring entry.
208 struct Vring {
209     // The queue config. This doesn't get mutated by the queue workers.
210     queue: QueueConfig,
211     doorbell: Option<Interrupt>,
212     enabled: bool,
213 }
214 
215 impl Vring {
new(max_size: u16, features: u64) -> Self216     fn new(max_size: u16, features: u64) -> Self {
217         Self {
218             queue: QueueConfig::new(max_size, features),
219             doorbell: None,
220             enabled: false,
221         }
222     }
223 
reset(&mut self)224     fn reset(&mut self) {
225         self.queue.reset();
226         self.doorbell = None;
227         self.enabled = false;
228     }
229 }
230 
231 /// Ops for running vhost-user over a stream (i.e. regular protocol).
232 pub(super) struct VhostUserRegularOps;
233 
234 impl VhostUserRegularOps {
set_mem_table( contexts: &[VhostUserMemoryRegion], files: Vec<File>, ) -> VhostResult<(GuestMemory, Vec<MappingInfo>)>235     pub fn set_mem_table(
236         contexts: &[VhostUserMemoryRegion],
237         files: Vec<File>,
238     ) -> VhostResult<(GuestMemory, Vec<MappingInfo>)> {
239         if files.len() != contexts.len() {
240             return Err(VhostError::InvalidParam(
241                 "number of files & contexts was not equal",
242             ));
243         }
244 
245         let mut regions = Vec::with_capacity(files.len());
246         for (region, file) in contexts.iter().zip(files.into_iter()) {
247             let region = MemoryRegion::new_from_shm(
248                 region.memory_size,
249                 GuestAddress(region.guest_phys_addr),
250                 region.mmap_offset,
251                 Arc::new(
252                     SharedMemory::from_safe_descriptor(
253                         SafeDescriptor::from(file),
254                         region.memory_size,
255                     )
256                     .unwrap(),
257                 ),
258             )
259             .map_err(|e| {
260                 error!("failed to create a memory region: {}", e);
261                 VhostError::InvalidOperation
262             })?;
263             regions.push(region);
264         }
265         let guest_mem = GuestMemory::from_regions(regions).map_err(|e| {
266             error!("failed to create guest memory: {}", e);
267             VhostError::InvalidOperation
268         })?;
269 
270         let vmm_maps = contexts
271             .iter()
272             .map(|region| MappingInfo {
273                 vmm_addr: region.user_addr,
274                 guest_phys: region.guest_phys_addr,
275                 size: region.memory_size,
276             })
277             .collect();
278         Ok((guest_mem, vmm_maps))
279     }
280 }
281 
282 /// An adapter that implements `vmm_vhost::Backend` for any type implementing `VhostUserDevice`.
283 pub struct DeviceRequestHandler<T: VhostUserDevice> {
284     vrings: Vec<Vring>,
285     owned: bool,
286     vmm_maps: Option<Vec<MappingInfo>>,
287     mem: Option<GuestMemory>,
288     acked_features: u64,
289     acked_protocol_features: VhostUserProtocolFeatures,
290     backend: T,
291     backend_req_connection: Arc<Mutex<VhostBackendReqConnectionState>>,
292     // Thread processing active device state FD.
293     device_state_thread: Option<DeviceStateThread>,
294 }
295 
296 enum DeviceStateThread {
297     Save(WorkerThread<serde_json::Result<()>>),
298     Load(WorkerThread<serde_json::Result<DeviceRequestHandlerSnapshot>>),
299 }
300 
301 #[derive(Serialize, Deserialize)]
302 pub struct DeviceRequestHandlerSnapshot {
303     acked_features: u64,
304     acked_protocol_features: u64,
305     backend: AnySnapshot,
306 }
307 
308 impl<T: VhostUserDevice> DeviceRequestHandler<T> {
309     /// Creates a vhost-user handler instance for `backend`.
new(mut backend: T) -> Self310     pub(crate) fn new(mut backend: T) -> Self {
311         let mut vrings = Vec::with_capacity(backend.max_queue_num());
312         for _ in 0..backend.max_queue_num() {
313             vrings.push(Vring::new(Queue::MAX_SIZE, backend.features()));
314         }
315 
316         // VhostUserDevice implementations must support `enter_suspended_state()`.
317         // Call it on startup to ensure it works and to initialize the device in a suspended state.
318         backend
319             .enter_suspended_state()
320             .expect("enter_suspended_state failed on device init");
321 
322         DeviceRequestHandler {
323             vrings,
324             owned: false,
325             vmm_maps: None,
326             mem: None,
327             acked_features: 0,
328             acked_protocol_features: VhostUserProtocolFeatures::empty(),
329             backend,
330             backend_req_connection: Arc::new(Mutex::new(
331                 VhostBackendReqConnectionState::NoConnection,
332             )),
333             device_state_thread: None,
334         }
335     }
336 
337     /// Check if all queues are stopped.
338     ///
339     /// The device can be suspended with `enter_suspended_state()` only when all queues are stopped.
all_queues_stopped(&self) -> bool340     fn all_queues_stopped(&self) -> bool {
341         self.vrings.iter().all(|vring| !vring.queue.ready())
342     }
343 }
344 
345 impl<T: VhostUserDevice> AsRef<T> for DeviceRequestHandler<T> {
as_ref(&self) -> &T346     fn as_ref(&self) -> &T {
347         &self.backend
348     }
349 }
350 
351 impl<T: VhostUserDevice> AsMut<T> for DeviceRequestHandler<T> {
as_mut(&mut self) -> &mut T352     fn as_mut(&mut self) -> &mut T {
353         &mut self.backend
354     }
355 }
356 
357 impl<T: VhostUserDevice> vmm_vhost::Backend for DeviceRequestHandler<T> {
set_owner(&mut self) -> VhostResult<()>358     fn set_owner(&mut self) -> VhostResult<()> {
359         if self.owned {
360             return Err(VhostError::InvalidOperation);
361         }
362         self.owned = true;
363         Ok(())
364     }
365 
reset_owner(&mut self) -> VhostResult<()>366     fn reset_owner(&mut self) -> VhostResult<()> {
367         self.owned = false;
368         self.acked_features = 0;
369         self.backend.reset();
370         Ok(())
371     }
372 
get_features(&mut self) -> VhostResult<u64>373     fn get_features(&mut self) -> VhostResult<u64> {
374         let features = self.backend.features();
375         Ok(features)
376     }
377 
set_features(&mut self, features: u64) -> VhostResult<()>378     fn set_features(&mut self, features: u64) -> VhostResult<()> {
379         if !self.owned {
380             return Err(VhostError::InvalidOperation);
381         }
382 
383         let unexpected_features = features & !self.backend.features();
384         if unexpected_features != 0 {
385             error!("unexpected set_features {:#x}", unexpected_features);
386             return Err(VhostError::InvalidParam("unexpected set_features"));
387         }
388 
389         if let Err(e) = self.backend.ack_features(features) {
390             error!("failed to acknowledge features 0x{:x}: {}", features, e);
391             return Err(VhostError::InvalidOperation);
392         }
393 
394         self.acked_features |= features;
395 
396         // If VHOST_USER_F_PROTOCOL_FEATURES has not been negotiated, the ring is initialized in an
397         // enabled state.
398         // If VHOST_USER_F_PROTOCOL_FEATURES has been negotiated, the ring is initialized in a
399         // disabled state.
400         // Client must not pass data to/from the backend until ring is enabled by
401         // VHOST_USER_SET_VRING_ENABLE with parameter 1, or after it has been disabled by
402         // VHOST_USER_SET_VRING_ENABLE with parameter 0.
403         let vring_enabled = self.acked_features & 1 << VHOST_USER_F_PROTOCOL_FEATURES != 0;
404         for v in &mut self.vrings {
405             v.enabled = vring_enabled;
406         }
407 
408         Ok(())
409     }
410 
get_protocol_features(&mut self) -> VhostResult<VhostUserProtocolFeatures>411     fn get_protocol_features(&mut self) -> VhostResult<VhostUserProtocolFeatures> {
412         Ok(self.backend.protocol_features())
413     }
414 
set_protocol_features(&mut self, features: u64) -> VhostResult<()>415     fn set_protocol_features(&mut self, features: u64) -> VhostResult<()> {
416         let features = match VhostUserProtocolFeatures::from_bits(features) {
417             Some(proto_features) => proto_features,
418             None => {
419                 error!(
420                     "unsupported bits in VHOST_USER_SET_PROTOCOL_FEATURES: {:#x}",
421                     features
422                 );
423                 return Err(VhostError::InvalidOperation);
424             }
425         };
426         let supported = self.backend.protocol_features();
427         self.acked_protocol_features = features & supported;
428         Ok(())
429     }
430 
set_mem_table( &mut self, contexts: &[VhostUserMemoryRegion], files: Vec<File>, ) -> VhostResult<()>431     fn set_mem_table(
432         &mut self,
433         contexts: &[VhostUserMemoryRegion],
434         files: Vec<File>,
435     ) -> VhostResult<()> {
436         let (guest_mem, vmm_maps) = VhostUserRegularOps::set_mem_table(contexts, files)?;
437         self.mem = Some(guest_mem);
438         self.vmm_maps = Some(vmm_maps);
439         Ok(())
440     }
441 
get_queue_num(&mut self) -> VhostResult<u64>442     fn get_queue_num(&mut self) -> VhostResult<u64> {
443         Ok(self.vrings.len() as u64)
444     }
445 
set_vring_num(&mut self, index: u32, num: u32) -> VhostResult<()>446     fn set_vring_num(&mut self, index: u32, num: u32) -> VhostResult<()> {
447         if index as usize >= self.vrings.len() || num == 0 || num > Queue::MAX_SIZE.into() {
448             return Err(VhostError::InvalidParam(
449                 "set_vring_num: invalid index or num",
450             ));
451         }
452         self.vrings[index as usize].queue.set_size(num as u16);
453 
454         Ok(())
455     }
456 
set_vring_addr( &mut self, index: u32, _flags: VhostUserVringAddrFlags, descriptor: u64, used: u64, available: u64, _log: u64, ) -> VhostResult<()>457     fn set_vring_addr(
458         &mut self,
459         index: u32,
460         _flags: VhostUserVringAddrFlags,
461         descriptor: u64,
462         used: u64,
463         available: u64,
464         _log: u64,
465     ) -> VhostResult<()> {
466         if index as usize >= self.vrings.len() {
467             return Err(VhostError::InvalidParam(
468                 "set_vring_addr: index out of range",
469             ));
470         }
471 
472         let vmm_maps = self
473             .vmm_maps
474             .as_ref()
475             .ok_or(VhostError::InvalidParam("set_vring_addr: missing vmm_maps"))?;
476         let vring = &mut self.vrings[index as usize];
477         vring
478             .queue
479             .set_desc_table(vmm_va_to_gpa(vmm_maps, descriptor)?);
480         vring
481             .queue
482             .set_avail_ring(vmm_va_to_gpa(vmm_maps, available)?);
483         vring.queue.set_used_ring(vmm_va_to_gpa(vmm_maps, used)?);
484 
485         Ok(())
486     }
487 
set_vring_base(&mut self, index: u32, base: u32) -> VhostResult<()>488     fn set_vring_base(&mut self, index: u32, base: u32) -> VhostResult<()> {
489         if index as usize >= self.vrings.len() {
490             return Err(VhostError::InvalidParam(
491                 "set_vring_base: index out of range",
492             ));
493         }
494 
495         let vring = &mut self.vrings[index as usize];
496         vring.queue.set_next_avail(Wrapping(base as u16));
497         vring.queue.set_next_used(Wrapping(base as u16));
498 
499         Ok(())
500     }
501 
get_vring_base(&mut self, index: u32) -> VhostResult<VhostUserVringState>502     fn get_vring_base(&mut self, index: u32) -> VhostResult<VhostUserVringState> {
503         let vring = self
504             .vrings
505             .get_mut(index as usize)
506             .ok_or(VhostError::InvalidParam(
507                 "get_vring_base: index out of range",
508             ))?;
509 
510         // Quotation from vhost-user spec:
511         // "The back-end must [...] stop ring upon receiving VHOST_USER_GET_VRING_BASE."
512         // We only call `queue.set_ready()` when starting the queue, so if the queue is ready, that
513         // means it is started and should be stopped.
514         let vring_base = if vring.queue.ready() {
515             let queue = match self.backend.stop_queue(index as usize) {
516                 Ok(q) => q,
517                 Err(e) => {
518                     error!("Failed to stop queue in get_vring_base: {:#}", e);
519                     return Err(VhostError::BackendInternalError);
520                 }
521             };
522 
523             trace!("stopped queue {index}");
524             vring.reset();
525 
526             if self.all_queues_stopped() {
527                 trace!("all queues stopped; entering suspended state");
528                 self.backend
529                     .enter_suspended_state()
530                     .map_err(VhostError::EnterSuspendedState)?;
531             }
532 
533             queue.next_avail_to_process()
534         } else {
535             0
536         };
537 
538         Ok(VhostUserVringState::new(index, vring_base.into()))
539     }
540 
set_vring_kick(&mut self, index: u8, file: Option<File>) -> VhostResult<()>541     fn set_vring_kick(&mut self, index: u8, file: Option<File>) -> VhostResult<()> {
542         if index as usize >= self.vrings.len() {
543             return Err(VhostError::InvalidParam(
544                 "set_vring_kick: index out of range",
545             ));
546         }
547 
548         let vring = &mut self.vrings[index as usize];
549         if vring.queue.ready() {
550             error!("kick fd cannot replaced after queue is started");
551             return Err(VhostError::InvalidOperation);
552         }
553 
554         let file = file.ok_or(VhostError::InvalidParam("missing file for set_vring_kick"))?;
555 
556         // Remove O_NONBLOCK from kick_fd. Otherwise, uring_executor will fails when we read
557         // values via `next_val()` later.
558         // This is only required (and can only be done) on Unix platforms.
559         #[cfg(any(target_os = "android", target_os = "linux"))]
560         if let Err(e) = clear_fd_flags(file.as_raw_fd(), libc::O_NONBLOCK) {
561             error!("failed to remove O_NONBLOCK for kick fd: {}", e);
562             return Err(VhostError::InvalidParam(
563                 "could not remove O_NONBLOCK from vring_kick",
564             ));
565         }
566 
567         let kick_evt = Event::from(SafeDescriptor::from(file));
568 
569         // Enable any virtqueue features that were negotiated (like VIRTIO_RING_F_EVENT_IDX).
570         vring.queue.ack_features(self.acked_features);
571         vring.queue.set_ready(true);
572 
573         let mem = self
574             .mem
575             .as_ref()
576             .cloned()
577             .ok_or(VhostError::InvalidOperation)?;
578 
579         let doorbell = vring.doorbell.clone().ok_or(VhostError::InvalidOperation)?;
580 
581         let queue = match vring.queue.activate(&mem, kick_evt, doorbell) {
582             Ok(queue) => queue,
583             Err(e) => {
584                 error!("failed to activate vring: {:#}", e);
585                 return Err(VhostError::BackendInternalError);
586             }
587         };
588 
589         if let Err(e) = self.backend.start_queue(index as usize, queue, mem) {
590             error!("Failed to start queue {}: {}", index, e);
591             return Err(VhostError::BackendInternalError);
592         }
593         trace!("started queue {index}");
594 
595         Ok(())
596     }
597 
set_vring_call(&mut self, index: u8, file: Option<File>) -> VhostResult<()>598     fn set_vring_call(&mut self, index: u8, file: Option<File>) -> VhostResult<()> {
599         if index as usize >= self.vrings.len() {
600             return Err(VhostError::InvalidParam(
601                 "set_vring_call: index out of range",
602             ));
603         }
604 
605         let backend_req_conn = self.backend_req_connection.clone();
606         let signal_config_change_fn = Box::new(move || match &*backend_req_conn.lock() {
607             VhostBackendReqConnectionState::Connected(frontend) => {
608                 if let Err(e) = frontend.send_config_changed() {
609                     error!("Failed to notify config change: {:#}", e);
610                 }
611             }
612             VhostBackendReqConnectionState::NoConnection => {
613                 error!("No Backend request connection found");
614             }
615         });
616 
617         let file = file.ok_or(VhostError::InvalidParam("missing file for set_vring_call"))?;
618         self.vrings[index as usize].doorbell = Some(Interrupt::new_vhost_user(
619             Event::from(SafeDescriptor::from(file)),
620             signal_config_change_fn,
621         ));
622         Ok(())
623     }
624 
set_vring_err(&mut self, _index: u8, _fd: Option<File>) -> VhostResult<()>625     fn set_vring_err(&mut self, _index: u8, _fd: Option<File>) -> VhostResult<()> {
626         // TODO
627         Ok(())
628     }
629 
set_vring_enable(&mut self, index: u32, enable: bool) -> VhostResult<()>630     fn set_vring_enable(&mut self, index: u32, enable: bool) -> VhostResult<()> {
631         if index as usize >= self.vrings.len() {
632             return Err(VhostError::InvalidParam(
633                 "set_vring_enable: index out of range",
634             ));
635         }
636 
637         // This request should be handled only when VHOST_USER_F_PROTOCOL_FEATURES
638         // has been negotiated.
639         if self.acked_features & 1 << VHOST_USER_F_PROTOCOL_FEATURES == 0 {
640             return Err(VhostError::InvalidOperation);
641         }
642 
643         // Backend must not pass data to/from the ring until ring is enabled by
644         // VHOST_USER_SET_VRING_ENABLE with parameter 1, or after it has been disabled by
645         // VHOST_USER_SET_VRING_ENABLE with parameter 0.
646         self.vrings[index as usize].enabled = enable;
647 
648         Ok(())
649     }
650 
get_config( &mut self, offset: u32, size: u32, _flags: VhostUserConfigFlags, ) -> VhostResult<Vec<u8>>651     fn get_config(
652         &mut self,
653         offset: u32,
654         size: u32,
655         _flags: VhostUserConfigFlags,
656     ) -> VhostResult<Vec<u8>> {
657         let mut data = vec![0; size as usize];
658         self.backend.read_config(u64::from(offset), &mut data);
659         Ok(data)
660     }
661 
set_config( &mut self, offset: u32, buf: &[u8], _flags: VhostUserConfigFlags, ) -> VhostResult<()>662     fn set_config(
663         &mut self,
664         offset: u32,
665         buf: &[u8],
666         _flags: VhostUserConfigFlags,
667     ) -> VhostResult<()> {
668         self.backend.write_config(u64::from(offset), buf);
669         Ok(())
670     }
671 
set_backend_req_fd(&mut self, ep: Connection<BackendReq>)672     fn set_backend_req_fd(&mut self, ep: Connection<BackendReq>) {
673         let conn = Arc::new(VhostBackendReqConnection::new(
674             FrontendClient::new(ep),
675             self.backend.get_shared_memory_region().map(|r| r.id),
676         ));
677 
678         {
679             let mut backend_req_conn = self.backend_req_connection.lock();
680             if let VhostBackendReqConnectionState::Connected(_) = &*backend_req_conn {
681                 warn!("Backend Request Connection already established. Overwriting");
682             }
683             *backend_req_conn = VhostBackendReqConnectionState::Connected(conn.clone());
684         }
685 
686         self.backend.set_backend_req_connection(conn);
687     }
688 
get_inflight_fd( &mut self, _inflight: &VhostUserInflight, ) -> VhostResult<(VhostUserInflight, File)>689     fn get_inflight_fd(
690         &mut self,
691         _inflight: &VhostUserInflight,
692     ) -> VhostResult<(VhostUserInflight, File)> {
693         unimplemented!("get_inflight_fd");
694     }
695 
set_inflight_fd(&mut self, _inflight: &VhostUserInflight, _file: File) -> VhostResult<()>696     fn set_inflight_fd(&mut self, _inflight: &VhostUserInflight, _file: File) -> VhostResult<()> {
697         unimplemented!("set_inflight_fd");
698     }
699 
get_max_mem_slots(&mut self) -> VhostResult<u64>700     fn get_max_mem_slots(&mut self) -> VhostResult<u64> {
701         //TODO
702         Ok(0)
703     }
704 
add_mem_region( &mut self, _region: &VhostUserSingleMemoryRegion, _fd: File, ) -> VhostResult<()>705     fn add_mem_region(
706         &mut self,
707         _region: &VhostUserSingleMemoryRegion,
708         _fd: File,
709     ) -> VhostResult<()> {
710         //TODO
711         Ok(())
712     }
713 
remove_mem_region(&mut self, _region: &VhostUserSingleMemoryRegion) -> VhostResult<()>714     fn remove_mem_region(&mut self, _region: &VhostUserSingleMemoryRegion) -> VhostResult<()> {
715         //TODO
716         Ok(())
717     }
718 
set_device_state_fd( &mut self, transfer_direction: VhostUserTransferDirection, migration_phase: VhostUserMigrationPhase, mut fd: File, ) -> VhostResult<Option<File>>719     fn set_device_state_fd(
720         &mut self,
721         transfer_direction: VhostUserTransferDirection,
722         migration_phase: VhostUserMigrationPhase,
723         mut fd: File,
724     ) -> VhostResult<Option<File>> {
725         if migration_phase != VhostUserMigrationPhase::Stopped {
726             return Err(VhostError::InvalidOperation);
727         }
728         if !self.all_queues_stopped() {
729             return Err(VhostError::InvalidOperation);
730         }
731         if self.device_state_thread.is_some() {
732             error!("must call check_device_state before starting new state transfer");
733             return Err(VhostError::InvalidOperation);
734         }
735         // `set_device_state_fd` is designed to allow snapshot/restore concurrently with other
736         // methods, but, for simplicitly, we do those operations inline and only spawn a thread to
737         // handle the serialization and data transfer (the latter which seems necessary to
738         // implement the API correctly without, e.g., deadlocking because a pipe is full).
739         match transfer_direction {
740             VhostUserTransferDirection::Save => {
741                 // Snapshot the state.
742                 let snapshot = DeviceRequestHandlerSnapshot {
743                     acked_features: self.acked_features,
744                     acked_protocol_features: self.acked_protocol_features.bits(),
745                     backend: self.backend.snapshot().map_err(VhostError::SnapshotError)?,
746                 };
747                 // Spawn thread to write the serialized bytes.
748                 self.device_state_thread = Some(DeviceStateThread::Save(WorkerThread::start(
749                     "device_state_save",
750                     move |_kill_event| serde_json::to_writer(&mut fd, &snapshot),
751                 )));
752                 Ok(None)
753             }
754             VhostUserTransferDirection::Load => {
755                 // Spawn a thread to read the bytes and deserialize. Restore will happen in
756                 // `check_device_state`.
757                 self.device_state_thread = Some(DeviceStateThread::Load(WorkerThread::start(
758                     "device_state_load",
759                     move |_kill_event| serde_json::from_reader(&mut fd),
760                 )));
761                 Ok(None)
762             }
763         }
764     }
765 
check_device_state(&mut self) -> VhostResult<()>766     fn check_device_state(&mut self) -> VhostResult<()> {
767         let Some(thread) = self.device_state_thread.take() else {
768             error!("check_device_state: no active state transfer");
769             return Err(VhostError::InvalidOperation);
770         };
771         match thread {
772             DeviceStateThread::Save(worker) => {
773                 worker.stop().map_err(|e| {
774                     error!("device state save thread failed: {:#}", e);
775                     VhostError::BackendInternalError
776                 })?;
777                 Ok(())
778             }
779             DeviceStateThread::Load(worker) => {
780                 let snapshot = worker.stop().map_err(|e| {
781                     error!("device state load thread failed: {:#}", e);
782                     VhostError::BackendInternalError
783                 })?;
784                 self.acked_features = snapshot.acked_features;
785                 self.acked_protocol_features =
786                     VhostUserProtocolFeatures::from_bits(snapshot.acked_protocol_features)
787                         .with_context(|| {
788                             format!(
789                                 "unsupported bits in acked_protocol_features: {:#x}",
790                                 snapshot.acked_protocol_features
791                             )
792                         })
793                         .map_err(VhostError::RestoreError)?;
794                 self.backend
795                     .restore(snapshot.backend)
796                     .map_err(VhostError::RestoreError)?;
797                 Ok(())
798             }
799         }
800     }
801 
get_shared_memory_regions(&mut self) -> VhostResult<Vec<VhostSharedMemoryRegion>>802     fn get_shared_memory_regions(&mut self) -> VhostResult<Vec<VhostSharedMemoryRegion>> {
803         Ok(if let Some(r) = self.backend.get_shared_memory_region() {
804             vec![VhostSharedMemoryRegion::new(r.id, r.length)]
805         } else {
806             Vec::new()
807         })
808     }
809 }
810 
811 /// Indicates the state of backend request connection
812 pub enum VhostBackendReqConnectionState {
813     /// A backend request connection (`VhostBackendReqConnection`) is established
814     Connected(Arc<VhostBackendReqConnection>),
815     /// No backend request connection has been established yet
816     NoConnection,
817 }
818 
819 /// Keeps track of Vhost user backend request connection.
820 pub struct VhostBackendReqConnection {
821     conn: Arc<Mutex<FrontendClient>>,
822     shmem_info: Mutex<Option<ShmemInfo>>,
823 }
824 
/// Bookkeeping for one shared memory region used by a backend.
#[derive(Clone)]
struct ShmemInfo {
    /// Id of the shared memory region that mappings are placed in.
    shmid: u8,
    /// Currently mapped ranges: offset within the region -> size of the mapping.
    mapped_regions: BTreeMap<u64 /* offset */, u64 /* size */>,
}
830 
831 impl VhostBackendReqConnection {
new(conn: FrontendClient, shmid: Option<u8>) -> Self832     pub fn new(conn: FrontendClient, shmid: Option<u8>) -> Self {
833         let shmem_info = Mutex::new(shmid.map(|shmid| ShmemInfo {
834             shmid,
835             mapped_regions: BTreeMap::new(),
836         }));
837         Self {
838             conn: Arc::new(Mutex::new(conn)),
839             shmem_info,
840         }
841     }
842 
843     /// Send `VHOST_USER_CONFIG_CHANGE_MSG` to the frontend
send_config_changed(&self) -> anyhow::Result<()>844     pub fn send_config_changed(&self) -> anyhow::Result<()> {
845         self.conn
846             .lock()
847             .handle_config_change()
848             .context("Could not send config change message")?;
849         Ok(())
850     }
851 
852     /// Create a SharedMemoryMapper trait object from the ShmemInfo.
take_shmem_mapper(&self) -> anyhow::Result<Box<dyn SharedMemoryMapper>>853     pub fn take_shmem_mapper(&self) -> anyhow::Result<Box<dyn SharedMemoryMapper>> {
854         let shmem_info = self
855             .shmem_info
856             .lock()
857             .take()
858             .context("could not take shared memory mapper information")?;
859 
860         Ok(Box::new(VhostShmemMapper {
861             conn: self.conn.clone(),
862             shmem_info,
863         }))
864     }
865 }
866 
867 struct VhostShmemMapper {
868     conn: Arc<Mutex<FrontendClient>>,
869     shmem_info: ShmemInfo,
870 }
871 
872 impl SharedMemoryMapper for VhostShmemMapper {
add_mapping( &mut self, source: VmMemorySource, offset: u64, prot: Protection, _cache: MemCacheType, ) -> anyhow::Result<()>873     fn add_mapping(
874         &mut self,
875         source: VmMemorySource,
876         offset: u64,
877         prot: Protection,
878         _cache: MemCacheType,
879     ) -> anyhow::Result<()> {
880         let size = match source {
881             VmMemorySource::Vulkan {
882                 descriptor,
883                 handle_type,
884                 memory_idx,
885                 device_uuid,
886                 driver_uuid,
887                 size,
888             } => {
889                 let msg = VhostUserGpuMapMsg::new(
890                     self.shmem_info.shmid,
891                     offset,
892                     size,
893                     memory_idx,
894                     handle_type,
895                     device_uuid,
896                     driver_uuid,
897                 );
898                 self.conn
899                     .lock()
900                     .gpu_map(&msg, &descriptor)
901                     .context("failed to map memory")?;
902                 size
903             }
904             VmMemorySource::ExternalMapping { ptr, size } => {
905                 let msg = VhostUserExternalMapMsg::new(self.shmem_info.shmid, offset, size, ptr);
906                 self.conn
907                     .lock()
908                     .external_map(&msg)
909                     .context("failed to map memory")?;
910                 size
911             }
912             source => {
913                 // The last two sources use the same VhostUserShmemMapMsg, continue matching here
914                 // on the aliased `source` above.
915                 let (descriptor, fd_offset, size) = match source {
916                     VmMemorySource::Descriptor {
917                         descriptor,
918                         offset,
919                         size,
920                     } => (descriptor, offset, size),
921                     VmMemorySource::SharedMemory(shmem) => {
922                         let size = shmem.size();
923                         let descriptor = SafeDescriptor::from(shmem);
924                         (descriptor, 0, size)
925                     }
926                     _ => bail!("unsupported source"),
927                 };
928                 let flags = VhostUserShmemMapMsgFlags::from(prot);
929                 let msg = VhostUserShmemMapMsg::new(
930                     self.shmem_info.shmid,
931                     offset,
932                     fd_offset,
933                     size,
934                     flags,
935                 );
936                 self.conn
937                     .lock()
938                     .shmem_map(&msg, &descriptor)
939                     .context("failed to map memory")?;
940                 size
941             }
942         };
943 
944         self.shmem_info.mapped_regions.insert(offset, size);
945         Ok(())
946     }
947 
remove_mapping(&mut self, offset: u64) -> anyhow::Result<()>948     fn remove_mapping(&mut self, offset: u64) -> anyhow::Result<()> {
949         let size = self
950             .shmem_info
951             .mapped_regions
952             .remove(&offset)
953             .context("unknown offset")?;
954         let msg = VhostUserShmemUnmapMsg::new(self.shmem_info.shmid, offset, size);
955         self.conn
956             .lock()
957             .shmem_unmap(&msg)
958             .context("failed to map memory")
959             .map(|_| ())
960     }
961 }
962 
963 pub(crate) struct WorkerState<T, U> {
964     pub(crate) queue_task: TaskHandle<U>,
965     pub(crate) queue: T,
966 }
967 
968 /// Errors for device operations
969 #[derive(Debug, ThisError)]
970 pub enum Error {
971     #[error("worker not found when stopping queue")]
972     WorkerNotFound,
973 }
974 
#[cfg(test)]
mod tests {
    use std::sync::mpsc::channel;
    use std::sync::Barrier;

    use anyhow::bail;
    use base::Event;
    use vmm_vhost::BackendServer;
    use vmm_vhost::FrontendReq;
    use zerocopy::FromBytes;
    use zerocopy::FromZeros;
    use zerocopy::Immutable;
    use zerocopy::IntoBytes;
    use zerocopy::KnownLayout;

    use super::*;
    use crate::virtio::vhost_user_frontend::VhostUserFrontend;
    use crate::virtio::DeviceType;
    use crate::virtio::VirtioDevice;

    /// Config space exposed by `FakeBackend`.
    ///
    /// `packed(4)` caps the alignment at 4, so `y` directly follows `x` and there is no padding.
    #[derive(Clone, Copy, Debug, PartialEq, Eq, FromBytes, Immutable, IntoBytes, KnownLayout)]
    #[repr(C, packed(4))]
    struct FakeConfig {
        x: u32,
        y: u64,
    }

    /// Config contents the frontend is expected to read back in the lifecycle test.
    const FAKE_CONFIG_DATA: FakeConfig = FakeConfig { x: 1, y: 2 };

    /// Minimal `VhostUserDevice` used to drive `DeviceRequestHandler` in tests.
    pub(super) struct FakeBackend {
        avail_features: u64,
        acked_features: u64,
        active_queues: Vec<Option<Queue>>,
        allow_backend_req: bool,
        backend_conn: Option<Arc<VhostBackendReqConnection>>,
    }

    /// Serialized form of `FakeBackend`'s state; `restore` checks for this fixed payload.
    #[derive(Deserialize, Serialize)]
    struct FakeBackendSnapshot {
        data: Vec<u8>,
    }

    impl FakeBackend {
        const MAX_QUEUE_NUM: usize = 16;

        /// Creates a backend in this initial state:
        /// * offers only `VHOST_USER_F_PROTOCOL_FEATURES`;
        /// * has no active queues;
        /// * has backend requests disabled.
        pub(super) fn new() -> Self {
            let mut active_queues = Vec::new();
            active_queues.resize_with(Self::MAX_QUEUE_NUM, Default::default);
            Self {
                avail_features: 1 << VHOST_USER_F_PROTOCOL_FEATURES,
                acked_features: 0,
                active_queues,
                allow_backend_req: false,
                backend_conn: None,
            }
        }
    }

    impl VhostUserDevice for FakeBackend {
        fn max_queue_num(&self) -> usize {
            Self::MAX_QUEUE_NUM
        }

        fn features(&self) -> u64 {
            self.avail_features
        }

        fn ack_features(&mut self, value: u64) -> anyhow::Result<()> {
            let unrequested_features = value & !self.avail_features;
            if unrequested_features != 0 {
                // These are device feature bits, not vhost-user protocol features.
                bail!("invalid features are given: 0x{:x}", unrequested_features);
            }
            self.acked_features |= value;
            Ok(())
        }

        fn protocol_features(&self) -> VhostUserProtocolFeatures {
            let mut features =
                VhostUserProtocolFeatures::CONFIG | VhostUserProtocolFeatures::DEVICE_STATE;
            if self.allow_backend_req {
                features |= VhostUserProtocolFeatures::BACKEND_REQ;
            }
            features
        }

        fn read_config(&self, offset: u64, dst: &mut [u8]) {
            dst.copy_from_slice(&FAKE_CONFIG_DATA.as_bytes()[offset as usize..]);
        }

        fn reset(&mut self) {}

        fn start_queue(
            &mut self,
            idx: usize,
            queue: Queue,
            _mem: GuestMemory,
        ) -> anyhow::Result<()> {
            self.active_queues[idx] = Some(queue);
            Ok(())
        }

        fn stop_queue(&mut self, idx: usize) -> anyhow::Result<Queue> {
            Ok(self.active_queues[idx]
                .take()
                .ok_or(Error::WorkerNotFound)?)
        }

        fn set_backend_req_connection(&mut self, conn: Arc<VhostBackendReqConnection>) {
            self.backend_conn = Some(conn);
        }

        fn enter_suspended_state(&mut self) -> anyhow::Result<()> {
            Ok(())
        }

        fn snapshot(&mut self) -> anyhow::Result<AnySnapshot> {
            AnySnapshot::to_any(FakeBackendSnapshot {
                data: vec![1, 2, 3],
            })
            .context("failed to serialize snapshot")
        }

        fn restore(&mut self, data: AnySnapshot) -> anyhow::Result<()> {
            let snapshot: FakeBackendSnapshot =
                AnySnapshot::from_any(data).context("failed to deserialize snapshot")?;
            assert_eq!(snapshot.data, vec![1, 2, 3], "bad snapshot data");
            Ok(())
        }
    }

    #[test]
    fn test_vhost_user_lifecycle() {
        test_vhost_user_lifecycle_parameterized(false);
    }

    #[test]
    #[cfg(not(windows))] // Windows requires more complex connection setup.
    fn test_vhost_user_lifecycle_with_backend_req() {
        test_vhost_user_lifecycle_parameterized(true);
    }

    /// Runs a full frontend/backend lifecycle over a connected socket pair.
    ///
    /// The VMM side runs on a spawned thread; the device side handles the requests in order on
    /// this thread. The lifecycle covered is:
    /// * read_config;
    /// * activate, reset, activate again;
    /// * sleep, snapshot, restore, wake;
    /// * shutdown.
    ///
    /// With `allow_backend_req`, the backend request connection is also exercised after
    /// reactivation and after wake.
    fn test_vhost_user_lifecycle_parameterized(allow_backend_req: bool) {
        const QUEUES_NUM: usize = 2;

        let (client_connection, server_connection) =
            vmm_vhost::Connection::<FrontendReq>::pair().unwrap();

        let vmm_bar = Arc::new(Barrier::new(2));
        let dev_bar = vmm_bar.clone();

        let (ready_tx, ready_rx) = channel();
        let (shutdown_tx, shutdown_rx) = channel();

        std::thread::spawn(move || {
            // VMM side
            ready_rx.recv().unwrap(); // Ensure the device is ready.

            let mut vmm_device =
                VhostUserFrontend::new(DeviceType::Console, 0, client_connection, None, None)
                    .unwrap();

            println!("read_config");
            let mut config = FakeConfig::new_zeroed();
            vmm_device.read_config(0, config.as_mut_bytes());
            // Check if the obtained config data is correct.
            assert_eq!(config, FAKE_CONFIG_DATA);

            let activate = |vmm_device: &mut VhostUserFrontend| {
                let mem = GuestMemory::new(&[(GuestAddress(0x0), 0x10000)]).unwrap();
                let interrupt = Interrupt::new_for_test_with_msix();

                let mut queues = BTreeMap::new();
                for idx in 0..QUEUES_NUM {
                    let mut queue = QueueConfig::new(0x10, 0);
                    queue.set_ready(true);
                    let queue = queue
                        .activate(&mem, Event::new().unwrap(), interrupt.clone())
                        .expect("QueueConfig::activate");
                    queues.insert(idx, queue);
                }

                println!("activate");
                vmm_device.activate(mem, interrupt, queues).unwrap();
            };

            activate(&mut vmm_device);

            println!("reset");
            let reset_result = vmm_device.reset();
            assert!(
                reset_result.is_ok(),
                "reset failed: {:#}",
                reset_result.unwrap_err()
            );

            activate(&mut vmm_device);

            println!("virtio_sleep");
            let queues = vmm_device
                .virtio_sleep()
                .unwrap()
                .expect("virtio_sleep unexpectedly returned None");

            println!("virtio_snapshot");
            let snapshot = vmm_device
                .virtio_snapshot()
                .expect("virtio_snapshot failed");
            println!("virtio_restore");
            vmm_device
                .virtio_restore(snapshot)
                .expect("virtio_restore failed");

            println!("virtio_wake");
            let mem = GuestMemory::new(&[(GuestAddress(0x0), 0x10000)]).unwrap();
            let interrupt = Interrupt::new_for_test_with_msix();
            vmm_device
                .virtio_wake(Some((mem, interrupt, queues)))
                .unwrap();

            println!("wait for shutdown signal");
            shutdown_rx.recv().unwrap();

            // The VMM side is supposed to stop before the device side.
            println!("drop");
            drop(vmm_device);

            vmm_bar.wait();
        });

        // Device side
        let mut handler = DeviceRequestHandler::new(FakeBackend::new());
        handler.as_mut().allow_backend_req = allow_backend_req;

        // Notify listener is ready.
        ready_tx.send(()).unwrap();

        let mut req_handler = BackendServer::new(server_connection, handler);

        // VhostUserFrontend::new()
        handle_request(&mut req_handler, FrontendReq::SET_OWNER).unwrap();
        handle_request(&mut req_handler, FrontendReq::GET_FEATURES).unwrap();
        handle_request(&mut req_handler, FrontendReq::SET_FEATURES).unwrap();
        handle_request(&mut req_handler, FrontendReq::GET_PROTOCOL_FEATURES).unwrap();
        handle_request(&mut req_handler, FrontendReq::SET_PROTOCOL_FEATURES).unwrap();
        if allow_backend_req {
            handle_request(&mut req_handler, FrontendReq::SET_BACKEND_REQ_FD).unwrap();
        }

        // VhostUserFrontend::read_config()
        handle_request(&mut req_handler, FrontendReq::GET_CONFIG).unwrap();

        // VhostUserFrontend::activate()
        handle_request(&mut req_handler, FrontendReq::SET_MEM_TABLE).unwrap();
        for _ in 0..QUEUES_NUM {
            handle_request(&mut req_handler, FrontendReq::SET_VRING_NUM).unwrap();
            handle_request(&mut req_handler, FrontendReq::SET_VRING_ADDR).unwrap();
            handle_request(&mut req_handler, FrontendReq::SET_VRING_BASE).unwrap();
            handle_request(&mut req_handler, FrontendReq::SET_VRING_CALL).unwrap();
            handle_request(&mut req_handler, FrontendReq::SET_VRING_KICK).unwrap();
            handle_request(&mut req_handler, FrontendReq::SET_VRING_ENABLE).unwrap();
        }

        // VhostUserFrontend::reset()
        for _ in 0..QUEUES_NUM {
            handle_request(&mut req_handler, FrontendReq::SET_VRING_ENABLE).unwrap();
            handle_request(&mut req_handler, FrontendReq::GET_VRING_BASE).unwrap();
        }

        // VhostUserFrontend::activate()
        handle_request(&mut req_handler, FrontendReq::SET_MEM_TABLE).unwrap();
        for _ in 0..QUEUES_NUM {
            handle_request(&mut req_handler, FrontendReq::SET_VRING_NUM).unwrap();
            handle_request(&mut req_handler, FrontendReq::SET_VRING_ADDR).unwrap();
            handle_request(&mut req_handler, FrontendReq::SET_VRING_BASE).unwrap();
            handle_request(&mut req_handler, FrontendReq::SET_VRING_CALL).unwrap();
            handle_request(&mut req_handler, FrontendReq::SET_VRING_KICK).unwrap();
            handle_request(&mut req_handler, FrontendReq::SET_VRING_ENABLE).unwrap();
        }

        if allow_backend_req {
            // Make sure the connection still works even after reset/reactivate.
            req_handler
                .as_ref()
                .as_ref()
                .backend_conn
                .as_ref()
                .expect("backend_conn missing")
                .send_config_changed()
                .expect("send_config_changed failed");
        }

        // VhostUserFrontend::virtio_sleep()
        for _ in 0..QUEUES_NUM {
            handle_request(&mut req_handler, FrontendReq::SET_VRING_ENABLE).unwrap();
            handle_request(&mut req_handler, FrontendReq::GET_VRING_BASE).unwrap();
        }

        // VhostUserFrontend::virtio_snapshot()
        handle_request(&mut req_handler, FrontendReq::SET_DEVICE_STATE_FD).unwrap();
        handle_request(&mut req_handler, FrontendReq::CHECK_DEVICE_STATE).unwrap();
        // VhostUserFrontend::virtio_restore()
        handle_request(&mut req_handler, FrontendReq::SET_DEVICE_STATE_FD).unwrap();
        handle_request(&mut req_handler, FrontendReq::CHECK_DEVICE_STATE).unwrap();

        // VhostUserFrontend::virtio_wake()
        handle_request(&mut req_handler, FrontendReq::SET_MEM_TABLE).unwrap();
        for _ in 0..QUEUES_NUM {
            handle_request(&mut req_handler, FrontendReq::SET_VRING_NUM).unwrap();
            handle_request(&mut req_handler, FrontendReq::SET_VRING_ADDR).unwrap();
            handle_request(&mut req_handler, FrontendReq::SET_VRING_BASE).unwrap();
            handle_request(&mut req_handler, FrontendReq::SET_VRING_CALL).unwrap();
            handle_request(&mut req_handler, FrontendReq::SET_VRING_KICK).unwrap();
            handle_request(&mut req_handler, FrontendReq::SET_VRING_ENABLE).unwrap();
        }

        if allow_backend_req {
            // Make sure the connection still works even after sleep/wake.
            req_handler
                .as_ref()
                .as_ref()
                .backend_conn
                .as_ref()
                .expect("backend_conn missing")
                .send_config_changed()
                .expect("send_config_changed failed");
        }

        // Ask the client to shut down, then wait for it to finish.
        shutdown_tx.send(()).unwrap();
        dev_bar.wait();

        // Verify recv_header fails with `ClientExit` after the client has disconnected.
        match req_handler.recv_header() {
            Err(VhostError::ClientExit) => (),
            r => panic!("expected Err(ClientExit) but got {:?}", r),
        }
    }

    /// Receives one message and asserts that it is `expected_message_type`.
    ///
    /// The message is then dispatched to the backend, and the result of processing it is
    /// returned.
    fn handle_request<S: vmm_vhost::Backend>(
        handler: &mut BackendServer<S>,
        expected_message_type: FrontendReq,
    ) -> Result<(), VhostError> {
        let (hdr, files) = handler.recv_header()?;
        assert_eq!(hdr.get_code(), Ok(expected_message_type));
        handler.process_message(hdr, files)
    }
}
1325