1 // Copyright 2021 The ChromiumOS Authors
2 // Use of this source code is governed by a BSD-style license that can be
3 // found in the LICENSE file.
4
5 //! Library for implementing vhost-user device executables.
6 //!
7 //! This crate provides
8 //! * `VhostUserDevice` trait, which is a collection of methods to handle vhost-user requests, and
9 //! * `DeviceRequestHandler` struct, which makes a connection to a VMM and starts an event loop.
10 //!
11 //! They are expected to be used as follows:
12 //!
13 //! 1. Define a struct and implement `VhostUserDevice` for it.
14 //! 2. Create a `DeviceRequestHandler` with the backend struct.
15 //! 3. Drive the `DeviceRequestHandler::run` async fn with an executor.
16 //!
17 //! ```ignore
18 //! struct MyBackend {
19 //! /* fields */
20 //! }
21 //!
22 //! impl VhostUserDevice for MyBackend {
23 //! /* implement methods */
24 //! }
25 //!
26 //! fn main() -> Result<(), Box<dyn Error>> {
27 //! let backend = MyBackend { /* initialize fields */ };
28 //! let handler = DeviceRequestHandler::new(backend);
//!     let socket = std::path::Path::new("/path/to/socket");
30 //! let ex = cros_async::Executor::new()?;
31 //!
32 //! if let Err(e) = ex.run_until(handler.run(socket, &ex)) {
33 //! eprintln!("error happened: {}", e);
34 //! }
35 //! Ok(())
36 //! }
37 //! ```
38 // Implementation note:
39 // This code lets us take advantage of the vmm_vhost low level implementation of the vhost user
40 // protocol. DeviceRequestHandler implements the Backend trait from vmm_vhost, and includes some
41 // common code for setting up guest memory and managing partially configured vrings.
42 // DeviceRequestHandler::run watches the vhost-user socket and then calls handle_request() when it
43 // becomes readable. handle_request() reads and parses the message and then calls one of the
44 // Backend trait methods. These dispatch back to the supplied VhostUserDevice implementation (this
45 // is what our devices implement).
46
47 pub(super) mod sys;
48
49 use std::collections::BTreeMap;
50 use std::convert::From;
51 use std::fs::File;
52 use std::num::Wrapping;
53 #[cfg(any(target_os = "android", target_os = "linux"))]
54 use std::os::unix::io::AsRawFd;
55 use std::sync::Arc;
56
57 use anyhow::bail;
58 use anyhow::Context;
59 #[cfg(any(target_os = "android", target_os = "linux"))]
60 use base::clear_fd_flags;
61 use base::error;
62 use base::trace;
63 use base::warn;
64 use base::Event;
65 use base::Protection;
66 use base::SafeDescriptor;
67 use base::SharedMemory;
68 use base::WorkerThread;
69 use cros_async::TaskHandle;
70 use hypervisor::MemCacheType;
71 use serde::Deserialize;
72 use serde::Serialize;
73 use snapshot::AnySnapshot;
74 use sync::Mutex;
75 use thiserror::Error as ThisError;
76 use vm_control::VmMemorySource;
77 use vm_memory::GuestAddress;
78 use vm_memory::GuestMemory;
79 use vm_memory::MemoryRegion;
80 use vmm_vhost::message::VhostSharedMemoryRegion;
81 use vmm_vhost::message::VhostUserConfigFlags;
82 use vmm_vhost::message::VhostUserExternalMapMsg;
83 use vmm_vhost::message::VhostUserGpuMapMsg;
84 use vmm_vhost::message::VhostUserInflight;
85 use vmm_vhost::message::VhostUserMemoryRegion;
86 use vmm_vhost::message::VhostUserMigrationPhase;
87 use vmm_vhost::message::VhostUserProtocolFeatures;
88 use vmm_vhost::message::VhostUserShmemMapMsg;
89 use vmm_vhost::message::VhostUserShmemMapMsgFlags;
90 use vmm_vhost::message::VhostUserShmemUnmapMsg;
91 use vmm_vhost::message::VhostUserSingleMemoryRegion;
92 use vmm_vhost::message::VhostUserTransferDirection;
93 use vmm_vhost::message::VhostUserVringAddrFlags;
94 use vmm_vhost::message::VhostUserVringState;
95 use vmm_vhost::BackendReq;
96 use vmm_vhost::Connection;
97 use vmm_vhost::Error as VhostError;
98 use vmm_vhost::Frontend;
99 use vmm_vhost::FrontendClient;
100 use vmm_vhost::Result as VhostResult;
101 use vmm_vhost::VHOST_USER_F_PROTOCOL_FEATURES;
102
103 use crate::virtio::Interrupt;
104 use crate::virtio::Queue;
105 use crate::virtio::QueueConfig;
106 use crate::virtio::SharedMemoryMapper;
107 use crate::virtio::SharedMemoryRegion;
108
/// Keeps a mapping from the vmm's virtual addresses to guest addresses.
/// Used to translate addresses in messages from the vmm into guest physical addresses.
#[derive(Default)]
pub struct MappingInfo {
    /// Start of the region in the vmm's virtual address space (`user_addr` of the region).
    pub vmm_addr: u64,
    /// Guest physical address corresponding to `vmm_addr`.
    pub guest_phys: u64,
    /// Length of the region in bytes.
    pub size: u64,
}
117
vmm_va_to_gpa(maps: &[MappingInfo], vmm_va: u64) -> VhostResult<GuestAddress>118 pub fn vmm_va_to_gpa(maps: &[MappingInfo], vmm_va: u64) -> VhostResult<GuestAddress> {
119 for map in maps {
120 if vmm_va >= map.vmm_addr && vmm_va < map.vmm_addr + map.size {
121 return Ok(GuestAddress(vmm_va - map.vmm_addr + map.guest_phys));
122 }
123 }
124 Err(VhostError::InvalidMessage)
125 }
126
/// Trait for vhost-user devices. Analogous to the `VirtioDevice` trait.
///
/// In contrast with [`vmm_vhost::Backend`], which closely matches the vhost-user spec, this trait
/// is designed to follow crosvm conventions for implementing devices. `DeviceRequestHandler`
/// adapts an implementation of this trait to [`vmm_vhost::Backend`].
pub trait VhostUserDevice {
    /// The maximum number of queues that this backend can manage.
    ///
    /// `DeviceRequestHandler::new` creates one vring per queue, so this also bounds the queue
    /// indices accepted from the frontend.
    fn max_queue_num(&self) -> usize;

    /// The set of virtio feature bits that this backend supports.
    ///
    /// `DeviceRequestHandler` rejects `VHOST_USER_SET_FEATURES` requests that contain bits
    /// outside this set.
    fn features(&self) -> u64;

    /// Acknowledges that this set of features should be enabled.
    ///
    /// Implementations only need to handle device-specific feature bits; the `DeviceRequestHandler`
    /// framework will manage generic vhost and vring features.
    ///
    /// `DeviceRequestHandler` checks for valid features before calling this function, so the
    /// features in `value` will always be a subset of those advertised by `features()`.
    ///
    /// The default implementation accepts any features. Returning `Err` makes the
    /// `VHOST_USER_SET_FEATURES` request fail.
    fn ack_features(&mut self, _value: u64) -> anyhow::Result<()> {
        Ok(())
    }

    /// The set of protocol feature bits that this backend supports.
    ///
    /// Protocol features requested by the frontend are masked with this set before being
    /// recorded as acknowledged.
    fn protocol_features(&self) -> VhostUserProtocolFeatures;

    /// Reads this device's configuration space at `offset` into `dst`.
    ///
    /// For `VHOST_USER_GET_CONFIG`, `dst` is a zero-filled buffer whose length is the size
    /// requested by the frontend.
    fn read_config(&self, offset: u64, dst: &mut [u8]);

    /// Writes `data` to this device's configuration space at `offset`.
    ///
    /// The default implementation ignores the write.
    fn write_config(&self, _offset: u64, _data: &[u8]) {}

    /// Indicates that the backend should start processing requests for virtio queue number `idx`.
    /// This method must not block the current thread so device backends should either spawn an
    /// async task or another thread to handle messages from the Queue.
    ///
    /// Returning `Err` makes the triggering `VHOST_USER_SET_VRING_KICK` request fail.
    fn start_queue(&mut self, idx: usize, queue: Queue, mem: GuestMemory) -> anyhow::Result<()>;

    /// Indicates that the backend should stop processing requests for virtio queue number `idx`.
    /// This method should return the queue passed to `start_queue` for the corresponding `idx`.
    /// This method will only be called for queues that were previously started by `start_queue`.
    ///
    /// The returned queue's next-available index is reported to the frontend as the vring base.
    fn stop_queue(&mut self, idx: usize) -> anyhow::Result<Queue>;

    /// Resets the vhost-user backend. Called for `VHOST_USER_RESET_OWNER`.
    fn reset(&mut self);

    /// Returns the device's shared memory region if present.
    ///
    /// The default implementation reports no shared memory region.
    fn get_shared_memory_region(&self) -> Option<SharedMemoryRegion> {
        None
    }

    /// Accepts `VhostBackendReqConnection` to conduct Vhost backend to frontend message
    /// handling.
    ///
    /// The backend is given an `Arc` instead of full ownership so that the framework can also use
    /// the connection.
    ///
    /// This method will be called when `VhostUserProtocolFeatures::BACKEND_REQ` is
    /// negotiated.
    fn set_backend_req_connection(&mut self, _conn: Arc<VhostBackendReqConnection>) {}

    /// Enter the "suspended device state" described in the vhost-user spec. See the spec for
    /// requirements.
    ///
    /// One reasonably foolproof way to satisfy the requirements is to stop all worker threads.
    ///
    /// Called after a `stop_queue` call if there are no running queues left. Also called soon
    /// after device creation to ensure the device is acting suspended immediately on construction.
    ///
    /// The next `start_queue` call implicitly exits the "suspend device state".
    ///
    /// * Ok(()) => device successfully suspended
    /// * Err(_) => unrecoverable error
    fn enter_suspended_state(&mut self) -> anyhow::Result<()>;

    /// Snapshot device and return serialized state.
    ///
    /// Only requested while all queues are stopped (see `set_device_state_fd`).
    fn snapshot(&mut self) -> anyhow::Result<AnySnapshot>;

    /// Restore device state from a snapshot previously produced by `snapshot`.
    fn restore(&mut self, data: AnySnapshot) -> anyhow::Result<()>;
}
206
/// A virtio ring entry: the per-queue state collected from vhost-user messages before (and
/// while) the queue is handed to the backend.
struct Vring {
    // The queue config. This doesn't get mutated by the queue workers.
    queue: QueueConfig,
    // Interrupt built from the VHOST_USER_SET_VRING_CALL fd; `None` until that message arrives
    // or after `reset()`.
    doorbell: Option<Interrupt>,
    // Ring enable state driven by VHOST_USER_SET_FEATURES / VHOST_USER_SET_VRING_ENABLE.
    // NOTE(review): nothing in this handler reads the flag; confirm whether it should gate
    // queue processing.
    enabled: bool,
}
214
215 impl Vring {
new(max_size: u16, features: u64) -> Self216 fn new(max_size: u16, features: u64) -> Self {
217 Self {
218 queue: QueueConfig::new(max_size, features),
219 doorbell: None,
220 enabled: false,
221 }
222 }
223
reset(&mut self)224 fn reset(&mut self) {
225 self.queue.reset();
226 self.doorbell = None;
227 self.enabled = false;
228 }
229 }
230
/// Ops for running vhost-user over a stream (i.e. regular protocol).
///
/// Stateless; it only groups helper functions such as `set_mem_table`.
pub(super) struct VhostUserRegularOps;
233
234 impl VhostUserRegularOps {
set_mem_table( contexts: &[VhostUserMemoryRegion], files: Vec<File>, ) -> VhostResult<(GuestMemory, Vec<MappingInfo>)>235 pub fn set_mem_table(
236 contexts: &[VhostUserMemoryRegion],
237 files: Vec<File>,
238 ) -> VhostResult<(GuestMemory, Vec<MappingInfo>)> {
239 if files.len() != contexts.len() {
240 return Err(VhostError::InvalidParam(
241 "number of files & contexts was not equal",
242 ));
243 }
244
245 let mut regions = Vec::with_capacity(files.len());
246 for (region, file) in contexts.iter().zip(files.into_iter()) {
247 let region = MemoryRegion::new_from_shm(
248 region.memory_size,
249 GuestAddress(region.guest_phys_addr),
250 region.mmap_offset,
251 Arc::new(
252 SharedMemory::from_safe_descriptor(
253 SafeDescriptor::from(file),
254 region.memory_size,
255 )
256 .unwrap(),
257 ),
258 )
259 .map_err(|e| {
260 error!("failed to create a memory region: {}", e);
261 VhostError::InvalidOperation
262 })?;
263 regions.push(region);
264 }
265 let guest_mem = GuestMemory::from_regions(regions).map_err(|e| {
266 error!("failed to create guest memory: {}", e);
267 VhostError::InvalidOperation
268 })?;
269
270 let vmm_maps = contexts
271 .iter()
272 .map(|region| MappingInfo {
273 vmm_addr: region.user_addr,
274 guest_phys: region.guest_phys_addr,
275 size: region.memory_size,
276 })
277 .collect();
278 Ok((guest_mem, vmm_maps))
279 }
280 }
281
/// An adapter that implements `vmm_vhost::Backend` for any type implementing `VhostUserDevice`.
///
/// It tracks the generic vhost-user state (ownership, negotiated features, guest memory, vring
/// setup) and forwards device-specific work to the wrapped backend.
pub struct DeviceRequestHandler<T: VhostUserDevice> {
    // One entry per queue, sized by `VhostUserDevice::max_queue_num()`.
    vrings: Vec<Vring>,
    // Set by VHOST_USER_SET_OWNER and cleared by VHOST_USER_RESET_OWNER.
    owned: bool,
    // vmm virtual address -> guest physical address translations from the last SET_MEM_TABLE.
    vmm_maps: Option<Vec<MappingInfo>>,
    // Guest memory from the last SET_MEM_TABLE; `None` until one is received.
    mem: Option<GuestMemory>,
    // Virtio features accepted through VHOST_USER_SET_FEATURES.
    acked_features: u64,
    // Protocol features requested by the frontend, masked to what the backend supports.
    acked_protocol_features: VhostUserProtocolFeatures,
    backend: T,
    // Shared with the config-change callbacks installed on each vring's doorbell.
    backend_req_connection: Arc<Mutex<VhostBackendReqConnectionState>>,
    // Thread processing active device state FD.
    device_state_thread: Option<DeviceStateThread>,
}
295
/// Worker thread performing an in-progress device state transfer started by
/// `set_device_state_fd`; its result is collected by `check_device_state`.
enum DeviceStateThread {
    // Serializes a snapshot as JSON into the transfer fd.
    Save(WorkerThread<serde_json::Result<()>>),
    // Reads JSON from the transfer fd and deserializes it into a snapshot to restore.
    Load(WorkerThread<serde_json::Result<DeviceRequestHandlerSnapshot>>),
}
300
/// Serialized state of a `DeviceRequestHandler` used for device state transfer.
#[derive(Serialize, Deserialize)]
pub struct DeviceRequestHandlerSnapshot {
    // Negotiated virtio feature bits.
    acked_features: u64,
    // Raw bits of the negotiated `VhostUserProtocolFeatures`.
    acked_protocol_features: u64,
    // State returned by `VhostUserDevice::snapshot`.
    backend: AnySnapshot,
}
307
308 impl<T: VhostUserDevice> DeviceRequestHandler<T> {
309 /// Creates a vhost-user handler instance for `backend`.
new(mut backend: T) -> Self310 pub(crate) fn new(mut backend: T) -> Self {
311 let mut vrings = Vec::with_capacity(backend.max_queue_num());
312 for _ in 0..backend.max_queue_num() {
313 vrings.push(Vring::new(Queue::MAX_SIZE, backend.features()));
314 }
315
316 // VhostUserDevice implementations must support `enter_suspended_state()`.
317 // Call it on startup to ensure it works and to initialize the device in a suspended state.
318 backend
319 .enter_suspended_state()
320 .expect("enter_suspended_state failed on device init");
321
322 DeviceRequestHandler {
323 vrings,
324 owned: false,
325 vmm_maps: None,
326 mem: None,
327 acked_features: 0,
328 acked_protocol_features: VhostUserProtocolFeatures::empty(),
329 backend,
330 backend_req_connection: Arc::new(Mutex::new(
331 VhostBackendReqConnectionState::NoConnection,
332 )),
333 device_state_thread: None,
334 }
335 }
336
337 /// Check if all queues are stopped.
338 ///
339 /// The device can be suspended with `enter_suspended_state()` only when all queues are stopped.
all_queues_stopped(&self) -> bool340 fn all_queues_stopped(&self) -> bool {
341 self.vrings.iter().all(|vring| !vring.queue.ready())
342 }
343 }
344
/// Gives shared access to the wrapped `VhostUserDevice` backend.
impl<T: VhostUserDevice> AsRef<T> for DeviceRequestHandler<T> {
    fn as_ref(&self) -> &T {
        &self.backend
    }
}
350
/// Gives exclusive access to the wrapped `VhostUserDevice` backend.
impl<T: VhostUserDevice> AsMut<T> for DeviceRequestHandler<T> {
    fn as_mut(&mut self) -> &mut T {
        &mut self.backend
    }
}
356
357 impl<T: VhostUserDevice> vmm_vhost::Backend for DeviceRequestHandler<T> {
set_owner(&mut self) -> VhostResult<()>358 fn set_owner(&mut self) -> VhostResult<()> {
359 if self.owned {
360 return Err(VhostError::InvalidOperation);
361 }
362 self.owned = true;
363 Ok(())
364 }
365
reset_owner(&mut self) -> VhostResult<()>366 fn reset_owner(&mut self) -> VhostResult<()> {
367 self.owned = false;
368 self.acked_features = 0;
369 self.backend.reset();
370 Ok(())
371 }
372
get_features(&mut self) -> VhostResult<u64>373 fn get_features(&mut self) -> VhostResult<u64> {
374 let features = self.backend.features();
375 Ok(features)
376 }
377
set_features(&mut self, features: u64) -> VhostResult<()>378 fn set_features(&mut self, features: u64) -> VhostResult<()> {
379 if !self.owned {
380 return Err(VhostError::InvalidOperation);
381 }
382
383 let unexpected_features = features & !self.backend.features();
384 if unexpected_features != 0 {
385 error!("unexpected set_features {:#x}", unexpected_features);
386 return Err(VhostError::InvalidParam("unexpected set_features"));
387 }
388
389 if let Err(e) = self.backend.ack_features(features) {
390 error!("failed to acknowledge features 0x{:x}: {}", features, e);
391 return Err(VhostError::InvalidOperation);
392 }
393
394 self.acked_features |= features;
395
396 // If VHOST_USER_F_PROTOCOL_FEATURES has not been negotiated, the ring is initialized in an
397 // enabled state.
398 // If VHOST_USER_F_PROTOCOL_FEATURES has been negotiated, the ring is initialized in a
399 // disabled state.
400 // Client must not pass data to/from the backend until ring is enabled by
401 // VHOST_USER_SET_VRING_ENABLE with parameter 1, or after it has been disabled by
402 // VHOST_USER_SET_VRING_ENABLE with parameter 0.
403 let vring_enabled = self.acked_features & 1 << VHOST_USER_F_PROTOCOL_FEATURES != 0;
404 for v in &mut self.vrings {
405 v.enabled = vring_enabled;
406 }
407
408 Ok(())
409 }
410
get_protocol_features(&mut self) -> VhostResult<VhostUserProtocolFeatures>411 fn get_protocol_features(&mut self) -> VhostResult<VhostUserProtocolFeatures> {
412 Ok(self.backend.protocol_features())
413 }
414
set_protocol_features(&mut self, features: u64) -> VhostResult<()>415 fn set_protocol_features(&mut self, features: u64) -> VhostResult<()> {
416 let features = match VhostUserProtocolFeatures::from_bits(features) {
417 Some(proto_features) => proto_features,
418 None => {
419 error!(
420 "unsupported bits in VHOST_USER_SET_PROTOCOL_FEATURES: {:#x}",
421 features
422 );
423 return Err(VhostError::InvalidOperation);
424 }
425 };
426 let supported = self.backend.protocol_features();
427 self.acked_protocol_features = features & supported;
428 Ok(())
429 }
430
set_mem_table( &mut self, contexts: &[VhostUserMemoryRegion], files: Vec<File>, ) -> VhostResult<()>431 fn set_mem_table(
432 &mut self,
433 contexts: &[VhostUserMemoryRegion],
434 files: Vec<File>,
435 ) -> VhostResult<()> {
436 let (guest_mem, vmm_maps) = VhostUserRegularOps::set_mem_table(contexts, files)?;
437 self.mem = Some(guest_mem);
438 self.vmm_maps = Some(vmm_maps);
439 Ok(())
440 }
441
get_queue_num(&mut self) -> VhostResult<u64>442 fn get_queue_num(&mut self) -> VhostResult<u64> {
443 Ok(self.vrings.len() as u64)
444 }
445
set_vring_num(&mut self, index: u32, num: u32) -> VhostResult<()>446 fn set_vring_num(&mut self, index: u32, num: u32) -> VhostResult<()> {
447 if index as usize >= self.vrings.len() || num == 0 || num > Queue::MAX_SIZE.into() {
448 return Err(VhostError::InvalidParam(
449 "set_vring_num: invalid index or num",
450 ));
451 }
452 self.vrings[index as usize].queue.set_size(num as u16);
453
454 Ok(())
455 }
456
set_vring_addr( &mut self, index: u32, _flags: VhostUserVringAddrFlags, descriptor: u64, used: u64, available: u64, _log: u64, ) -> VhostResult<()>457 fn set_vring_addr(
458 &mut self,
459 index: u32,
460 _flags: VhostUserVringAddrFlags,
461 descriptor: u64,
462 used: u64,
463 available: u64,
464 _log: u64,
465 ) -> VhostResult<()> {
466 if index as usize >= self.vrings.len() {
467 return Err(VhostError::InvalidParam(
468 "set_vring_addr: index out of range",
469 ));
470 }
471
472 let vmm_maps = self
473 .vmm_maps
474 .as_ref()
475 .ok_or(VhostError::InvalidParam("set_vring_addr: missing vmm_maps"))?;
476 let vring = &mut self.vrings[index as usize];
477 vring
478 .queue
479 .set_desc_table(vmm_va_to_gpa(vmm_maps, descriptor)?);
480 vring
481 .queue
482 .set_avail_ring(vmm_va_to_gpa(vmm_maps, available)?);
483 vring.queue.set_used_ring(vmm_va_to_gpa(vmm_maps, used)?);
484
485 Ok(())
486 }
487
set_vring_base(&mut self, index: u32, base: u32) -> VhostResult<()>488 fn set_vring_base(&mut self, index: u32, base: u32) -> VhostResult<()> {
489 if index as usize >= self.vrings.len() {
490 return Err(VhostError::InvalidParam(
491 "set_vring_base: index out of range",
492 ));
493 }
494
495 let vring = &mut self.vrings[index as usize];
496 vring.queue.set_next_avail(Wrapping(base as u16));
497 vring.queue.set_next_used(Wrapping(base as u16));
498
499 Ok(())
500 }
501
get_vring_base(&mut self, index: u32) -> VhostResult<VhostUserVringState>502 fn get_vring_base(&mut self, index: u32) -> VhostResult<VhostUserVringState> {
503 let vring = self
504 .vrings
505 .get_mut(index as usize)
506 .ok_or(VhostError::InvalidParam(
507 "get_vring_base: index out of range",
508 ))?;
509
510 // Quotation from vhost-user spec:
511 // "The back-end must [...] stop ring upon receiving VHOST_USER_GET_VRING_BASE."
512 // We only call `queue.set_ready()` when starting the queue, so if the queue is ready, that
513 // means it is started and should be stopped.
514 let vring_base = if vring.queue.ready() {
515 let queue = match self.backend.stop_queue(index as usize) {
516 Ok(q) => q,
517 Err(e) => {
518 error!("Failed to stop queue in get_vring_base: {:#}", e);
519 return Err(VhostError::BackendInternalError);
520 }
521 };
522
523 trace!("stopped queue {index}");
524 vring.reset();
525
526 if self.all_queues_stopped() {
527 trace!("all queues stopped; entering suspended state");
528 self.backend
529 .enter_suspended_state()
530 .map_err(VhostError::EnterSuspendedState)?;
531 }
532
533 queue.next_avail_to_process()
534 } else {
535 0
536 };
537
538 Ok(VhostUserVringState::new(index, vring_base.into()))
539 }
540
set_vring_kick(&mut self, index: u8, file: Option<File>) -> VhostResult<()>541 fn set_vring_kick(&mut self, index: u8, file: Option<File>) -> VhostResult<()> {
542 if index as usize >= self.vrings.len() {
543 return Err(VhostError::InvalidParam(
544 "set_vring_kick: index out of range",
545 ));
546 }
547
548 let vring = &mut self.vrings[index as usize];
549 if vring.queue.ready() {
550 error!("kick fd cannot replaced after queue is started");
551 return Err(VhostError::InvalidOperation);
552 }
553
554 let file = file.ok_or(VhostError::InvalidParam("missing file for set_vring_kick"))?;
555
556 // Remove O_NONBLOCK from kick_fd. Otherwise, uring_executor will fails when we read
557 // values via `next_val()` later.
558 // This is only required (and can only be done) on Unix platforms.
559 #[cfg(any(target_os = "android", target_os = "linux"))]
560 if let Err(e) = clear_fd_flags(file.as_raw_fd(), libc::O_NONBLOCK) {
561 error!("failed to remove O_NONBLOCK for kick fd: {}", e);
562 return Err(VhostError::InvalidParam(
563 "could not remove O_NONBLOCK from vring_kick",
564 ));
565 }
566
567 let kick_evt = Event::from(SafeDescriptor::from(file));
568
569 // Enable any virtqueue features that were negotiated (like VIRTIO_RING_F_EVENT_IDX).
570 vring.queue.ack_features(self.acked_features);
571 vring.queue.set_ready(true);
572
573 let mem = self
574 .mem
575 .as_ref()
576 .cloned()
577 .ok_or(VhostError::InvalidOperation)?;
578
579 let doorbell = vring.doorbell.clone().ok_or(VhostError::InvalidOperation)?;
580
581 let queue = match vring.queue.activate(&mem, kick_evt, doorbell) {
582 Ok(queue) => queue,
583 Err(e) => {
584 error!("failed to activate vring: {:#}", e);
585 return Err(VhostError::BackendInternalError);
586 }
587 };
588
589 if let Err(e) = self.backend.start_queue(index as usize, queue, mem) {
590 error!("Failed to start queue {}: {}", index, e);
591 return Err(VhostError::BackendInternalError);
592 }
593 trace!("started queue {index}");
594
595 Ok(())
596 }
597
set_vring_call(&mut self, index: u8, file: Option<File>) -> VhostResult<()>598 fn set_vring_call(&mut self, index: u8, file: Option<File>) -> VhostResult<()> {
599 if index as usize >= self.vrings.len() {
600 return Err(VhostError::InvalidParam(
601 "set_vring_call: index out of range",
602 ));
603 }
604
605 let backend_req_conn = self.backend_req_connection.clone();
606 let signal_config_change_fn = Box::new(move || match &*backend_req_conn.lock() {
607 VhostBackendReqConnectionState::Connected(frontend) => {
608 if let Err(e) = frontend.send_config_changed() {
609 error!("Failed to notify config change: {:#}", e);
610 }
611 }
612 VhostBackendReqConnectionState::NoConnection => {
613 error!("No Backend request connection found");
614 }
615 });
616
617 let file = file.ok_or(VhostError::InvalidParam("missing file for set_vring_call"))?;
618 self.vrings[index as usize].doorbell = Some(Interrupt::new_vhost_user(
619 Event::from(SafeDescriptor::from(file)),
620 signal_config_change_fn,
621 ));
622 Ok(())
623 }
624
set_vring_err(&mut self, _index: u8, _fd: Option<File>) -> VhostResult<()>625 fn set_vring_err(&mut self, _index: u8, _fd: Option<File>) -> VhostResult<()> {
626 // TODO
627 Ok(())
628 }
629
set_vring_enable(&mut self, index: u32, enable: bool) -> VhostResult<()>630 fn set_vring_enable(&mut self, index: u32, enable: bool) -> VhostResult<()> {
631 if index as usize >= self.vrings.len() {
632 return Err(VhostError::InvalidParam(
633 "set_vring_enable: index out of range",
634 ));
635 }
636
637 // This request should be handled only when VHOST_USER_F_PROTOCOL_FEATURES
638 // has been negotiated.
639 if self.acked_features & 1 << VHOST_USER_F_PROTOCOL_FEATURES == 0 {
640 return Err(VhostError::InvalidOperation);
641 }
642
643 // Backend must not pass data to/from the ring until ring is enabled by
644 // VHOST_USER_SET_VRING_ENABLE with parameter 1, or after it has been disabled by
645 // VHOST_USER_SET_VRING_ENABLE with parameter 0.
646 self.vrings[index as usize].enabled = enable;
647
648 Ok(())
649 }
650
get_config( &mut self, offset: u32, size: u32, _flags: VhostUserConfigFlags, ) -> VhostResult<Vec<u8>>651 fn get_config(
652 &mut self,
653 offset: u32,
654 size: u32,
655 _flags: VhostUserConfigFlags,
656 ) -> VhostResult<Vec<u8>> {
657 let mut data = vec![0; size as usize];
658 self.backend.read_config(u64::from(offset), &mut data);
659 Ok(data)
660 }
661
set_config( &mut self, offset: u32, buf: &[u8], _flags: VhostUserConfigFlags, ) -> VhostResult<()>662 fn set_config(
663 &mut self,
664 offset: u32,
665 buf: &[u8],
666 _flags: VhostUserConfigFlags,
667 ) -> VhostResult<()> {
668 self.backend.write_config(u64::from(offset), buf);
669 Ok(())
670 }
671
set_backend_req_fd(&mut self, ep: Connection<BackendReq>)672 fn set_backend_req_fd(&mut self, ep: Connection<BackendReq>) {
673 let conn = Arc::new(VhostBackendReqConnection::new(
674 FrontendClient::new(ep),
675 self.backend.get_shared_memory_region().map(|r| r.id),
676 ));
677
678 {
679 let mut backend_req_conn = self.backend_req_connection.lock();
680 if let VhostBackendReqConnectionState::Connected(_) = &*backend_req_conn {
681 warn!("Backend Request Connection already established. Overwriting");
682 }
683 *backend_req_conn = VhostBackendReqConnectionState::Connected(conn.clone());
684 }
685
686 self.backend.set_backend_req_connection(conn);
687 }
688
get_inflight_fd( &mut self, _inflight: &VhostUserInflight, ) -> VhostResult<(VhostUserInflight, File)>689 fn get_inflight_fd(
690 &mut self,
691 _inflight: &VhostUserInflight,
692 ) -> VhostResult<(VhostUserInflight, File)> {
693 unimplemented!("get_inflight_fd");
694 }
695
set_inflight_fd(&mut self, _inflight: &VhostUserInflight, _file: File) -> VhostResult<()>696 fn set_inflight_fd(&mut self, _inflight: &VhostUserInflight, _file: File) -> VhostResult<()> {
697 unimplemented!("set_inflight_fd");
698 }
699
get_max_mem_slots(&mut self) -> VhostResult<u64>700 fn get_max_mem_slots(&mut self) -> VhostResult<u64> {
701 //TODO
702 Ok(0)
703 }
704
add_mem_region( &mut self, _region: &VhostUserSingleMemoryRegion, _fd: File, ) -> VhostResult<()>705 fn add_mem_region(
706 &mut self,
707 _region: &VhostUserSingleMemoryRegion,
708 _fd: File,
709 ) -> VhostResult<()> {
710 //TODO
711 Ok(())
712 }
713
remove_mem_region(&mut self, _region: &VhostUserSingleMemoryRegion) -> VhostResult<()>714 fn remove_mem_region(&mut self, _region: &VhostUserSingleMemoryRegion) -> VhostResult<()> {
715 //TODO
716 Ok(())
717 }
718
set_device_state_fd( &mut self, transfer_direction: VhostUserTransferDirection, migration_phase: VhostUserMigrationPhase, mut fd: File, ) -> VhostResult<Option<File>>719 fn set_device_state_fd(
720 &mut self,
721 transfer_direction: VhostUserTransferDirection,
722 migration_phase: VhostUserMigrationPhase,
723 mut fd: File,
724 ) -> VhostResult<Option<File>> {
725 if migration_phase != VhostUserMigrationPhase::Stopped {
726 return Err(VhostError::InvalidOperation);
727 }
728 if !self.all_queues_stopped() {
729 return Err(VhostError::InvalidOperation);
730 }
731 if self.device_state_thread.is_some() {
732 error!("must call check_device_state before starting new state transfer");
733 return Err(VhostError::InvalidOperation);
734 }
735 // `set_device_state_fd` is designed to allow snapshot/restore concurrently with other
736 // methods, but, for simplicitly, we do those operations inline and only spawn a thread to
737 // handle the serialization and data transfer (the latter which seems necessary to
738 // implement the API correctly without, e.g., deadlocking because a pipe is full).
739 match transfer_direction {
740 VhostUserTransferDirection::Save => {
741 // Snapshot the state.
742 let snapshot = DeviceRequestHandlerSnapshot {
743 acked_features: self.acked_features,
744 acked_protocol_features: self.acked_protocol_features.bits(),
745 backend: self.backend.snapshot().map_err(VhostError::SnapshotError)?,
746 };
747 // Spawn thread to write the serialized bytes.
748 self.device_state_thread = Some(DeviceStateThread::Save(WorkerThread::start(
749 "device_state_save",
750 move |_kill_event| serde_json::to_writer(&mut fd, &snapshot),
751 )));
752 Ok(None)
753 }
754 VhostUserTransferDirection::Load => {
755 // Spawn a thread to read the bytes and deserialize. Restore will happen in
756 // `check_device_state`.
757 self.device_state_thread = Some(DeviceStateThread::Load(WorkerThread::start(
758 "device_state_load",
759 move |_kill_event| serde_json::from_reader(&mut fd),
760 )));
761 Ok(None)
762 }
763 }
764 }
765
check_device_state(&mut self) -> VhostResult<()>766 fn check_device_state(&mut self) -> VhostResult<()> {
767 let Some(thread) = self.device_state_thread.take() else {
768 error!("check_device_state: no active state transfer");
769 return Err(VhostError::InvalidOperation);
770 };
771 match thread {
772 DeviceStateThread::Save(worker) => {
773 worker.stop().map_err(|e| {
774 error!("device state save thread failed: {:#}", e);
775 VhostError::BackendInternalError
776 })?;
777 Ok(())
778 }
779 DeviceStateThread::Load(worker) => {
780 let snapshot = worker.stop().map_err(|e| {
781 error!("device state load thread failed: {:#}", e);
782 VhostError::BackendInternalError
783 })?;
784 self.acked_features = snapshot.acked_features;
785 self.acked_protocol_features =
786 VhostUserProtocolFeatures::from_bits(snapshot.acked_protocol_features)
787 .with_context(|| {
788 format!(
789 "unsupported bits in acked_protocol_features: {:#x}",
790 snapshot.acked_protocol_features
791 )
792 })
793 .map_err(VhostError::RestoreError)?;
794 self.backend
795 .restore(snapshot.backend)
796 .map_err(VhostError::RestoreError)?;
797 Ok(())
798 }
799 }
800 }
801
get_shared_memory_regions(&mut self) -> VhostResult<Vec<VhostSharedMemoryRegion>>802 fn get_shared_memory_regions(&mut self) -> VhostResult<Vec<VhostSharedMemoryRegion>> {
803 Ok(if let Some(r) = self.backend.get_shared_memory_region() {
804 vec![VhostSharedMemoryRegion::new(r.id, r.length)]
805 } else {
806 Vec::new()
807 })
808 }
809 }
810
/// Indicates the state of the backend request connection.
///
/// Starts as `NoConnection` and becomes `Connected` when the frontend provides a backend
/// request channel.
pub enum VhostBackendReqConnectionState {
    /// A backend request connection (`VhostBackendReqConnection`) is established.
    Connected(Arc<VhostBackendReqConnection>),
    /// No backend request connection has been established yet.
    NoConnection,
}
818
/// Keeps track of the vhost-user backend request connection (backend-to-frontend messages).
pub struct VhostBackendReqConnection {
    // Shared with any `VhostShmemMapper` created by `take_shmem_mapper`.
    conn: Arc<Mutex<FrontendClient>>,
    // Shared memory mapping state. `None` if the device has no shared memory region or the
    // mapper has already been taken.
    shmem_info: Mutex<Option<ShmemInfo>>,
}
824
/// Shared memory region identity plus the mappings made in it through the frontend.
#[derive(Clone)]
struct ShmemInfo {
    // Region ID sent in every map message.
    shmid: u8,
    // Mappings recorded by `add_mapping`, keyed by offset within the region.
    mapped_regions: BTreeMap<u64 /* offset */, u64 /* size */>,
}
830
831 impl VhostBackendReqConnection {
new(conn: FrontendClient, shmid: Option<u8>) -> Self832 pub fn new(conn: FrontendClient, shmid: Option<u8>) -> Self {
833 let shmem_info = Mutex::new(shmid.map(|shmid| ShmemInfo {
834 shmid,
835 mapped_regions: BTreeMap::new(),
836 }));
837 Self {
838 conn: Arc::new(Mutex::new(conn)),
839 shmem_info,
840 }
841 }
842
843 /// Send `VHOST_USER_CONFIG_CHANGE_MSG` to the frontend
send_config_changed(&self) -> anyhow::Result<()>844 pub fn send_config_changed(&self) -> anyhow::Result<()> {
845 self.conn
846 .lock()
847 .handle_config_change()
848 .context("Could not send config change message")?;
849 Ok(())
850 }
851
852 /// Create a SharedMemoryMapper trait object from the ShmemInfo.
take_shmem_mapper(&self) -> anyhow::Result<Box<dyn SharedMemoryMapper>>853 pub fn take_shmem_mapper(&self) -> anyhow::Result<Box<dyn SharedMemoryMapper>> {
854 let shmem_info = self
855 .shmem_info
856 .lock()
857 .take()
858 .context("could not take shared memory mapper information")?;
859
860 Ok(Box::new(VhostShmemMapper {
861 conn: self.conn.clone(),
862 shmem_info,
863 }))
864 }
865 }
866
/// `SharedMemoryMapper` that asks the frontend, over the backend request connection, to map
/// memory into the device's shared memory region.
struct VhostShmemMapper {
    // Backend request connection used to send map messages.
    conn: Arc<Mutex<FrontendClient>>,
    // Region ID and the mappings made so far.
    shmem_info: ShmemInfo,
}
871
872 impl SharedMemoryMapper for VhostShmemMapper {
add_mapping( &mut self, source: VmMemorySource, offset: u64, prot: Protection, _cache: MemCacheType, ) -> anyhow::Result<()>873 fn add_mapping(
874 &mut self,
875 source: VmMemorySource,
876 offset: u64,
877 prot: Protection,
878 _cache: MemCacheType,
879 ) -> anyhow::Result<()> {
880 let size = match source {
881 VmMemorySource::Vulkan {
882 descriptor,
883 handle_type,
884 memory_idx,
885 device_uuid,
886 driver_uuid,
887 size,
888 } => {
889 let msg = VhostUserGpuMapMsg::new(
890 self.shmem_info.shmid,
891 offset,
892 size,
893 memory_idx,
894 handle_type,
895 device_uuid,
896 driver_uuid,
897 );
898 self.conn
899 .lock()
900 .gpu_map(&msg, &descriptor)
901 .context("failed to map memory")?;
902 size
903 }
904 VmMemorySource::ExternalMapping { ptr, size } => {
905 let msg = VhostUserExternalMapMsg::new(self.shmem_info.shmid, offset, size, ptr);
906 self.conn
907 .lock()
908 .external_map(&msg)
909 .context("failed to map memory")?;
910 size
911 }
912 source => {
913 // The last two sources use the same VhostUserShmemMapMsg, continue matching here
914 // on the aliased `source` above.
915 let (descriptor, fd_offset, size) = match source {
916 VmMemorySource::Descriptor {
917 descriptor,
918 offset,
919 size,
920 } => (descriptor, offset, size),
921 VmMemorySource::SharedMemory(shmem) => {
922 let size = shmem.size();
923 let descriptor = SafeDescriptor::from(shmem);
924 (descriptor, 0, size)
925 }
926 _ => bail!("unsupported source"),
927 };
928 let flags = VhostUserShmemMapMsgFlags::from(prot);
929 let msg = VhostUserShmemMapMsg::new(
930 self.shmem_info.shmid,
931 offset,
932 fd_offset,
933 size,
934 flags,
935 );
936 self.conn
937 .lock()
938 .shmem_map(&msg, &descriptor)
939 .context("failed to map memory")?;
940 size
941 }
942 };
943
944 self.shmem_info.mapped_regions.insert(offset, size);
945 Ok(())
946 }
947
remove_mapping(&mut self, offset: u64) -> anyhow::Result<()>948 fn remove_mapping(&mut self, offset: u64) -> anyhow::Result<()> {
949 let size = self
950 .shmem_info
951 .mapped_regions
952 .remove(&offset)
953 .context("unknown offset")?;
954 let msg = VhostUserShmemUnmapMsg::new(self.shmem_info.shmid, offset, size);
955 self.conn
956 .lock()
957 .shmem_unmap(&msg)
958 .context("failed to map memory")
959 .map(|_| ())
960 }
961 }
962
963 pub(crate) struct WorkerState<T, U> {
964 pub(crate) queue_task: TaskHandle<U>,
965 pub(crate) queue: T,
966 }
967
968 /// Errors for device operations
969 #[derive(Debug, ThisError)]
970 pub enum Error {
971 #[error("worker not found when stopping queue")]
972 WorkerNotFound,
973 }
974
#[cfg(test)]
mod tests {
    use std::sync::mpsc::channel;
    use std::sync::Barrier;

    use anyhow::bail;
    use base::Event;
    use vmm_vhost::BackendServer;
    use vmm_vhost::FrontendReq;
    use zerocopy::FromBytes;
    use zerocopy::FromZeros;
    use zerocopy::Immutable;
    use zerocopy::IntoBytes;
    use zerocopy::KnownLayout;

    use super::*;
    use crate::virtio::vhost_user_frontend::VhostUserFrontend;
    use crate::virtio::DeviceType;
    use crate::virtio::VirtioDevice;

    /// Config space served by `FakeBackend::read_config`. The VMM side checks that it reads
    /// back unchanged.
    #[derive(Clone, Copy, Debug, PartialEq, Eq, FromBytes, Immutable, IntoBytes, KnownLayout)]
    #[repr(C, packed(4))]
    struct FakeConfig {
        x: u32,
        y: u64,
    }

    const FAKE_CONFIG_DATA: FakeConfig = FakeConfig { x: 1, y: 2 };

    /// Minimal `VhostUserDevice` that only records feature, queue and connection state.
    pub(super) struct FakeBackend {
        avail_features: u64,
        acked_features: u64,
        /// One slot per possible queue; `Some` while the queue is started.
        active_queues: Vec<Option<Queue>>,
        /// When set, `BACKEND_REQ` is advertised in the protocol features.
        allow_backend_req: bool,
        /// Backend request connection handed over via `set_backend_req_connection`.
        backend_conn: Option<Arc<VhostBackendReqConnection>>,
    }

    /// Serialized form produced by `FakeBackend::snapshot` and checked by `restore`.
    #[derive(Deserialize, Serialize)]
    struct FakeBackendSnapshot {
        data: Vec<u8>,
    }

    impl FakeBackend {
        const MAX_QUEUE_NUM: usize = 16;

        pub(super) fn new() -> Self {
            let mut active_queues = Vec::new();
            active_queues.resize_with(Self::MAX_QUEUE_NUM, Default::default);
            Self {
                avail_features: 1 << VHOST_USER_F_PROTOCOL_FEATURES,
                acked_features: 0,
                active_queues,
                allow_backend_req: false,
                backend_conn: None,
            }
        }
    }

    impl VhostUserDevice for FakeBackend {
        fn max_queue_num(&self) -> usize {
            Self::MAX_QUEUE_NUM
        }

        fn features(&self) -> u64 {
            self.avail_features
        }

        fn ack_features(&mut self, value: u64) -> anyhow::Result<()> {
            let unrequested_features = value & !self.avail_features;
            if unrequested_features != 0 {
                bail!(
                    "invalid protocol features are given: 0x{:x}",
                    unrequested_features
                );
            }
            self.acked_features |= value;
            Ok(())
        }

        fn protocol_features(&self) -> VhostUserProtocolFeatures {
            let mut features =
                VhostUserProtocolFeatures::CONFIG | VhostUserProtocolFeatures::DEVICE_STATE;
            if self.allow_backend_req {
                features |= VhostUserProtocolFeatures::BACKEND_REQ;
            }
            features
        }

        fn read_config(&self, offset: u64, dst: &mut [u8]) {
            dst.copy_from_slice(&FAKE_CONFIG_DATA.as_bytes()[offset as usize..]);
        }

        fn reset(&mut self) {}

        fn start_queue(
            &mut self,
            idx: usize,
            queue: Queue,
            _mem: GuestMemory,
        ) -> anyhow::Result<()> {
            self.active_queues[idx] = Some(queue);
            Ok(())
        }

        fn stop_queue(&mut self, idx: usize) -> anyhow::Result<Queue> {
            Ok(self.active_queues[idx]
                .take()
                .ok_or(Error::WorkerNotFound)?)
        }

        fn set_backend_req_connection(&mut self, conn: Arc<VhostBackendReqConnection>) {
            self.backend_conn = Some(conn);
        }

        fn enter_suspended_state(&mut self) -> anyhow::Result<()> {
            Ok(())
        }

        fn snapshot(&mut self) -> anyhow::Result<AnySnapshot> {
            AnySnapshot::to_any(FakeBackendSnapshot {
                data: vec![1, 2, 3],
            })
            .context("failed to serialize snapshot")
        }

        fn restore(&mut self, data: AnySnapshot) -> anyhow::Result<()> {
            let snapshot: FakeBackendSnapshot =
                AnySnapshot::from_any(data).context("failed to deserialize snapshot")?;
            assert_eq!(snapshot.data, vec![1, 2, 3], "bad snapshot data");
            Ok(())
        }
    }

    #[test]
    fn test_vhost_user_lifecycle() {
        test_vhost_user_lifecycle_parameterized(false);
    }

    #[test]
    #[cfg(not(windows))] // Windows requires more complex connection setup.
    fn test_vhost_user_lifecycle_with_backend_req() {
        test_vhost_user_lifecycle_parameterized(true);
    }

    /// Runs a `VhostUserFrontend` on a separate thread through new/read_config/activate/reset/
    /// activate/sleep/snapshot/restore/wake. On this thread, the device side checks that the
    /// expected request sequence arrives and is handled successfully.
    fn test_vhost_user_lifecycle_parameterized(allow_backend_req: bool) {
        const QUEUES_NUM: usize = 2;

        let (client_connection, server_connection) =
            vmm_vhost::Connection::<FrontendReq>::pair().unwrap();

        let vmm_bar = Arc::new(Barrier::new(2));
        let dev_bar = vmm_bar.clone();

        let (ready_tx, ready_rx) = channel();
        let (shutdown_tx, shutdown_rx) = channel();

        std::thread::spawn(move || {
            // VMM side
            ready_rx.recv().unwrap(); // Ensure the device is ready.

            let mut vmm_device =
                VhostUserFrontend::new(DeviceType::Console, 0, client_connection, None, None)
                    .unwrap();

            println!("read_config");
            let mut config = FakeConfig::new_zeroed();
            vmm_device.read_config(0, config.as_mut_bytes());
            // Check if the obtained config data is correct.
            assert_eq!(config, FAKE_CONFIG_DATA);

            let activate = |vmm_device: &mut VhostUserFrontend| {
                let mem = GuestMemory::new(&[(GuestAddress(0x0), 0x10000)]).unwrap();
                let interrupt = Interrupt::new_for_test_with_msix();

                let mut queues = BTreeMap::new();
                for idx in 0..QUEUES_NUM {
                    let mut queue = QueueConfig::new(0x10, 0);
                    queue.set_ready(true);
                    let queue = queue
                        .activate(&mem, Event::new().unwrap(), interrupt.clone())
                        .expect("QueueConfig::activate");
                    queues.insert(idx, queue);
                }

                println!("activate");
                vmm_device.activate(mem, interrupt, queues).unwrap();
            };

            activate(&mut vmm_device);

            println!("reset");
            let reset_result = vmm_device.reset();
            assert!(
                reset_result.is_ok(),
                "reset failed: {:#}",
                reset_result.unwrap_err()
            );

            activate(&mut vmm_device);

            println!("virtio_sleep");
            let queues = vmm_device
                .virtio_sleep()
                .unwrap()
                .expect("virtio_sleep unexpectedly returned None");

            println!("virtio_snapshot");
            let snapshot = vmm_device
                .virtio_snapshot()
                .expect("virtio_snapshot failed");
            println!("virtio_restore");
            vmm_device
                .virtio_restore(snapshot)
                .expect("virtio_restore failed");

            println!("virtio_wake");
            let mem = GuestMemory::new(&[(GuestAddress(0x0), 0x10000)]).unwrap();
            let interrupt = Interrupt::new_for_test_with_msix();
            vmm_device
                .virtio_wake(Some((mem, interrupt, queues)))
                .unwrap();

            println!("wait for shutdown signal");
            shutdown_rx.recv().unwrap();

            // The VMM side is supposed to stop before the device side.
            println!("drop");
            drop(vmm_device);

            vmm_bar.wait();
        });

        // Device side
        let mut handler = DeviceRequestHandler::new(FakeBackend::new());
        handler.as_mut().allow_backend_req = allow_backend_req;

        // Notify listener is ready.
        ready_tx.send(()).unwrap();

        let mut req_handler = BackendServer::new(server_connection, handler);

        // VhostUserFrontend::new()
        handle_request(&mut req_handler, FrontendReq::SET_OWNER).unwrap();
        handle_request(&mut req_handler, FrontendReq::GET_FEATURES).unwrap();
        handle_request(&mut req_handler, FrontendReq::SET_FEATURES).unwrap();
        handle_request(&mut req_handler, FrontendReq::GET_PROTOCOL_FEATURES).unwrap();
        handle_request(&mut req_handler, FrontendReq::SET_PROTOCOL_FEATURES).unwrap();
        if allow_backend_req {
            handle_request(&mut req_handler, FrontendReq::SET_BACKEND_REQ_FD).unwrap();
        }

        // VhostUserFrontend::read_config()
        handle_request(&mut req_handler, FrontendReq::GET_CONFIG).unwrap();

        // VhostUserFrontend::activate()
        handle_queue_setup_requests(&mut req_handler, QUEUES_NUM);

        // VhostUserFrontend::reset()
        handle_queue_teardown_requests(&mut req_handler, QUEUES_NUM);

        // VhostUserFrontend::activate()
        handle_queue_setup_requests(&mut req_handler, QUEUES_NUM);

        if allow_backend_req {
            // Make sure the connection still works even after reset/reactivate.
            check_backend_conn(req_handler.as_ref().as_ref());
        }

        // VhostUserFrontend::virtio_sleep()
        handle_queue_teardown_requests(&mut req_handler, QUEUES_NUM);

        // VhostUserFrontend::virtio_snapshot()
        handle_request(&mut req_handler, FrontendReq::SET_DEVICE_STATE_FD).unwrap();
        handle_request(&mut req_handler, FrontendReq::CHECK_DEVICE_STATE).unwrap();
        // VhostUserFrontend::virtio_restore()
        handle_request(&mut req_handler, FrontendReq::SET_DEVICE_STATE_FD).unwrap();
        handle_request(&mut req_handler, FrontendReq::CHECK_DEVICE_STATE).unwrap();

        // VhostUserFrontend::virtio_wake()
        handle_queue_setup_requests(&mut req_handler, QUEUES_NUM);

        if allow_backend_req {
            // Make sure the connection still works even after sleep/wake.
            check_backend_conn(req_handler.as_ref().as_ref());
        }

        // Ask the client to shutdown, then wait for it to finish.
        shutdown_tx.send(()).unwrap();
        dev_bar.wait();

        // Verify recv_header fails with `ClientExit` after the client has disconnected.
        match req_handler.recv_header() {
            Err(VhostError::ClientExit) => (),
            r => panic!("expected Err(ClientExit) but got {:?}", r),
        }
    }

    /// Handles the requests sent when a frontend activates (or wakes) the device. These are
    /// SET_MEM_TABLE, followed by the per-vring configuration requests for each of
    /// `queues_num` queues. Panics if any request is unexpected or fails.
    fn handle_queue_setup_requests<S: vmm_vhost::Backend>(
        handler: &mut BackendServer<S>,
        queues_num: usize,
    ) {
        handle_request(handler, FrontendReq::SET_MEM_TABLE).unwrap();
        for _ in 0..queues_num {
            handle_request(handler, FrontendReq::SET_VRING_NUM).unwrap();
            handle_request(handler, FrontendReq::SET_VRING_ADDR).unwrap();
            handle_request(handler, FrontendReq::SET_VRING_BASE).unwrap();
            handle_request(handler, FrontendReq::SET_VRING_CALL).unwrap();
            handle_request(handler, FrontendReq::SET_VRING_KICK).unwrap();
            handle_request(handler, FrontendReq::SET_VRING_ENABLE).unwrap();
        }
    }

    /// Handles the requests sent when a frontend resets or sleeps the device. These are
    /// SET_VRING_ENABLE followed by GET_VRING_BASE, for each of `queues_num` queues. Panics if
    /// any request is unexpected or fails.
    fn handle_queue_teardown_requests<S: vmm_vhost::Backend>(
        handler: &mut BackendServer<S>,
        queues_num: usize,
    ) {
        for _ in 0..queues_num {
            handle_request(handler, FrontendReq::SET_VRING_ENABLE).unwrap();
            handle_request(handler, FrontendReq::GET_VRING_BASE).unwrap();
        }
    }

    /// Asserts that `backend` received a backend request connection and can still use it to
    /// send a config-change message.
    fn check_backend_conn(backend: &FakeBackend) {
        backend
            .backend_conn
            .as_ref()
            .expect("backend_conn missing")
            .send_config_changed()
            .expect("send_config_changed failed");
    }

    /// Receives the next request, asserts it is `expected_message_type`, and dispatches it.
    fn handle_request<S: vmm_vhost::Backend>(
        handler: &mut BackendServer<S>,
        expected_message_type: FrontendReq,
    ) -> Result<(), VhostError> {
        let (hdr, files) = handler.recv_header()?;
        assert_eq!(hdr.get_code(), Ok(expected_message_type));
        handler.process_message(hdr, files)
    }
}
1325