1 // Copyright 2021 The ChromiumOS Authors
2 // Use of this source code is governed by a BSD-style license that can be
3 // found in the LICENSE file.
4
5 //! Library for implementing vhost-user device executables.
6 //!
//! This module provides
8 //! * `VhostUserDevice` trait, which is a collection of methods to handle vhost-user requests, and
9 //! * `DeviceRequestHandler` struct, which makes a connection to a VMM and starts an event loop.
10 //!
11 //! They are expected to be used as follows:
12 //!
13 //! 1. Define a struct and implement `VhostUserDevice` for it.
14 //! 2. Create a `DeviceRequestHandler` with the backend struct.
15 //! 3. Drive the `DeviceRequestHandler::run` async fn with an executor.
16 //!
17 //! ```ignore
18 //! struct MyBackend {
19 //! /* fields */
20 //! }
21 //!
22 //! impl VhostUserDevice for MyBackend {
23 //! /* implement methods */
24 //! }
25 //!
26 //! fn main() -> Result<(), Box<dyn Error>> {
27 //! let backend = MyBackend { /* initialize fields */ };
28 //! let handler = DeviceRequestHandler::new(backend);
//! let socket = std::path::Path::new("/path/to/socket");
30 //! let ex = cros_async::Executor::new()?;
31 //!
32 //! if let Err(e) = ex.run_until(handler.run(socket, &ex)) {
33 //! eprintln!("error happened: {}", e);
34 //! }
35 //! Ok(())
36 //! }
37 //! ```
38 // Implementation note:
39 // This code lets us take advantage of the vmm_vhost low level implementation of the vhost user
40 // protocol. DeviceRequestHandler implements the Backend trait from vmm_vhost, and includes some
41 // common code for setting up guest memory and managing partially configured vrings.
42 // DeviceRequestHandler::run watches the vhost-user socket and then calls handle_request() when it
43 // becomes readable. handle_request() reads and parses the message and then calls one of the
44 // Backend trait methods. These dispatch back to the supplied VhostUserDevice implementation (this
45 // is what our devices implement).
46
47 pub(super) mod sys;
48
49 use std::collections::BTreeMap;
50 use std::convert::From;
51 use std::fs::File;
52 use std::num::Wrapping;
53 #[cfg(any(target_os = "android", target_os = "linux"))]
54 use std::os::unix::io::AsRawFd;
55 use std::sync::Arc;
56
57 use anyhow::bail;
58 use anyhow::Context;
59 #[cfg(any(target_os = "android", target_os = "linux"))]
60 use base::clear_fd_flags;
61 use base::error;
62 use base::warn;
63 use base::Event;
64 use base::FromRawDescriptor;
65 use base::IntoRawDescriptor;
66 use base::Protection;
67 use base::SafeDescriptor;
68 use base::SharedMemory;
69 use cros_async::TaskHandle;
70 use hypervisor::MemCacheType;
71 use serde::Deserialize;
72 use serde::Serialize;
73 use sync::Mutex;
74 use thiserror::Error as ThisError;
75 use vm_control::VmMemorySource;
76 use vm_memory::GuestAddress;
77 use vm_memory::GuestMemory;
78 use vm_memory::MemoryRegion;
79 use vmm_vhost::message::VhostSharedMemoryRegion;
80 use vmm_vhost::message::VhostUserConfigFlags;
81 use vmm_vhost::message::VhostUserExternalMapMsg;
82 use vmm_vhost::message::VhostUserGpuMapMsg;
83 use vmm_vhost::message::VhostUserInflight;
84 use vmm_vhost::message::VhostUserMemoryRegion;
85 use vmm_vhost::message::VhostUserProtocolFeatures;
86 use vmm_vhost::message::VhostUserShmemMapMsg;
87 use vmm_vhost::message::VhostUserShmemMapMsgFlags;
88 use vmm_vhost::message::VhostUserShmemUnmapMsg;
89 use vmm_vhost::message::VhostUserSingleMemoryRegion;
90 use vmm_vhost::message::VhostUserVringAddrFlags;
91 use vmm_vhost::message::VhostUserVringState;
92 use vmm_vhost::BackendReq;
93 use vmm_vhost::Connection;
94 use vmm_vhost::Error as VhostError;
95 use vmm_vhost::Frontend;
96 use vmm_vhost::FrontendClient;
97 use vmm_vhost::Result as VhostResult;
98 use vmm_vhost::VHOST_USER_F_PROTOCOL_FEATURES;
99
100 use crate::virtio::Interrupt;
101 use crate::virtio::Queue;
102 use crate::virtio::QueueConfig;
103 use crate::virtio::SharedMemoryMapper;
104 use crate::virtio::SharedMemoryRegion;
105
/// Keeps a mapping from the VMM's virtual addresses to guest physical addresses.
///
/// Used to translate addresses in messages from the VMM (the vhost-user frontend) into guest
/// offsets. One entry describes a contiguous region of `size` bytes that starts at `vmm_addr`
/// in the VMM's address space and at `guest_phys` in guest physical memory.
#[derive(Clone, Copy, Debug, Default, PartialEq, Eq)]
pub struct MappingInfo {
    /// Start of the region in the VMM's virtual address space.
    pub vmm_addr: u64,
    /// Start of the region in guest physical address space.
    pub guest_phys: u64,
    /// Length of the region in bytes.
    pub size: u64,
}
114
vmm_va_to_gpa(maps: &[MappingInfo], vmm_va: u64) -> VhostResult<GuestAddress>115 pub fn vmm_va_to_gpa(maps: &[MappingInfo], vmm_va: u64) -> VhostResult<GuestAddress> {
116 for map in maps {
117 if vmm_va >= map.vmm_addr && vmm_va < map.vmm_addr + map.size {
118 return Ok(GuestAddress(vmm_va - map.vmm_addr + map.guest_phys));
119 }
120 }
121 Err(VhostError::InvalidMessage)
122 }
123
124 /// Trait for vhost-user devices. Analogous to the `VirtioDevice` trait.
125 ///
126 /// In contrast with [[vmm_vhost::Backend]], which closely matches the vhost-user spec, this trait
127 /// is designed to follow crosvm conventions for implementing devices.
128 pub trait VhostUserDevice {
129 /// The maximum number of queues that this backend can manage.
max_queue_num(&self) -> usize130 fn max_queue_num(&self) -> usize;
131
132 /// The set of feature bits that this backend supports.
features(&self) -> u64133 fn features(&self) -> u64;
134
135 /// Acknowledges that this set of features should be enabled.
ack_features(&mut self, value: u64) -> anyhow::Result<()>136 fn ack_features(&mut self, value: u64) -> anyhow::Result<()>;
137
138 /// Returns the set of enabled features.
acked_features(&self) -> u64139 fn acked_features(&self) -> u64;
140
141 /// The set of protocol feature bits that this backend supports.
protocol_features(&self) -> VhostUserProtocolFeatures142 fn protocol_features(&self) -> VhostUserProtocolFeatures;
143
144 /// Acknowledges that this set of protocol features should be enabled.
ack_protocol_features(&mut self, _value: u64) -> anyhow::Result<()>145 fn ack_protocol_features(&mut self, _value: u64) -> anyhow::Result<()>;
146
147 /// Returns the set of enabled protocol features.
acked_protocol_features(&self) -> u64148 fn acked_protocol_features(&self) -> u64;
149
150 /// Reads this device configuration space at `offset`.
read_config(&self, offset: u64, dst: &mut [u8])151 fn read_config(&self, offset: u64, dst: &mut [u8]);
152
153 /// writes `data` to this device's configuration space at `offset`.
write_config(&self, _offset: u64, _data: &[u8])154 fn write_config(&self, _offset: u64, _data: &[u8]) {}
155
156 /// Indicates that the backend should start processing requests for virtio queue number `idx`.
157 /// This method must not block the current thread so device backends should either spawn an
158 /// async task or another thread to handle messages from the Queue.
start_queue( &mut self, idx: usize, queue: Queue, mem: GuestMemory, doorbell: Interrupt, ) -> anyhow::Result<()>159 fn start_queue(
160 &mut self,
161 idx: usize,
162 queue: Queue,
163 mem: GuestMemory,
164 doorbell: Interrupt,
165 ) -> anyhow::Result<()>;
166
167 /// Indicates that the backend should stop processing requests for virtio queue number `idx`.
168 /// This method should return the queue passed to `start_queue` for the corresponding `idx`.
169 /// This method will only be called for queues that were previously started by `start_queue`.
stop_queue(&mut self, idx: usize) -> anyhow::Result<Queue>170 fn stop_queue(&mut self, idx: usize) -> anyhow::Result<Queue>;
171
172 /// Resets the vhost-user backend.
reset(&mut self)173 fn reset(&mut self);
174
175 /// Returns the device's shared memory region if present.
get_shared_memory_region(&self) -> Option<SharedMemoryRegion>176 fn get_shared_memory_region(&self) -> Option<SharedMemoryRegion> {
177 None
178 }
179
180 /// Accepts `VhostBackendReqConnection` to conduct Vhost backend to frontend message
181 /// handling.
182 ///
183 /// The backend is given an `Arc` instead of full ownership so that the framework can also use
184 /// the connection.
185 ///
186 /// This method will be called when `VhostUserProtocolFeatures::BACKEND_REQ` is
187 /// negotiated.
set_backend_req_connection(&mut self, _conn: Arc<VhostBackendReqConnection>)188 fn set_backend_req_connection(&mut self, _conn: Arc<VhostBackendReqConnection>) {
189 error!("set_backend_req_connection is not implemented");
190 }
191
192 /// Used to stop non queue workers that `VhostUserDevice::stop_queue` can't stop. May or may
193 /// not also stop all queue workers.
stop_non_queue_workers(&mut self) -> anyhow::Result<()>194 fn stop_non_queue_workers(&mut self) -> anyhow::Result<()> {
195 error!("sleep not implemented for vhost user device");
196 // TODO(rizhang): Return error once basic devices support this.
197 Ok(())
198 }
199
200 /// Snapshot device and return serialized bytes.
snapshot(&self) -> anyhow::Result<Vec<u8>>201 fn snapshot(&self) -> anyhow::Result<Vec<u8>> {
202 error!("snapshot not implemented for vhost user device");
203 // TODO(rizhang): Return error once basic devices support this.
204 Ok(Vec::new())
205 }
206
restore(&mut self, _data: Vec<u8>) -> anyhow::Result<()>207 fn restore(&mut self, _data: Vec<u8>) -> anyhow::Result<()> {
208 error!("restore not implemented for vhost user device");
209 // TODO(rizhang): Return error once basic devices support this.
210 Ok(())
211 }
212 }
213
/// A virtio ring entry.
///
/// Holds the frontend-provided configuration of one virtqueue plus the state the handler needs
/// to start, stop, pause and snapshot it.
struct Vring {
    // The queue config. This doesn't get mutated by the queue workers.
    queue: QueueConfig,
    // Interrupt handed to the backend in `start_queue`; set from the call fd received in
    // `set_vring_call`. `None` until that message arrives.
    doorbell: Option<Interrupt>,
    // Ring enable state, updated by `set_features` and `set_vring_enable`.
    // NOTE(review): in this part of the file it is only stored and snapshotted; confirm whether
    // anything else reads it.
    enabled: bool,
    // Active queue that is only `Some` when the device is sleeping.
    paused_queue: Option<Queue>,
}
223
/// Serializable form of a `Vring`, produced by `Vring::snapshot` and consumed by
/// `Vring::restore`.
///
/// serde serializes the fields in declaration order, so reordering them changes the snapshot
/// bytes.
#[derive(Serialize, Deserialize)]
struct VringSnapshot {
    // Snapshot of queue config.
    queue: serde_json::Value,
    // Snapshot of the activated queue state.
    paused_queue: Option<serde_json::Value>,
    // Value of `Vring::enabled` at snapshot time.
    enabled: bool,
}
232
233 impl Vring {
new(max_size: u16, features: u64) -> Self234 fn new(max_size: u16, features: u64) -> Self {
235 Self {
236 queue: QueueConfig::new(max_size, features),
237 doorbell: None,
238 enabled: false,
239 paused_queue: None,
240 }
241 }
242
reset(&mut self)243 fn reset(&mut self) {
244 self.queue.reset();
245 self.doorbell = None;
246 self.enabled = false;
247 self.paused_queue = None;
248 }
249
snapshot(&self) -> anyhow::Result<VringSnapshot>250 fn snapshot(&self) -> anyhow::Result<VringSnapshot> {
251 Ok(VringSnapshot {
252 queue: self.queue.snapshot()?,
253 enabled: self.enabled,
254 paused_queue: self
255 .paused_queue
256 .as_ref()
257 .map(Queue::snapshot)
258 .transpose()?,
259 })
260 }
261
restore( &mut self, vring_snapshot: VringSnapshot, mem: &GuestMemory, event: Option<Event>, ) -> anyhow::Result<()>262 fn restore(
263 &mut self,
264 vring_snapshot: VringSnapshot,
265 mem: &GuestMemory,
266 event: Option<Event>,
267 ) -> anyhow::Result<()> {
268 self.queue.restore(vring_snapshot.queue)?;
269 self.enabled = vring_snapshot.enabled;
270 self.paused_queue = vring_snapshot
271 .paused_queue
272 .map(|value| {
273 Queue::restore(
274 &self.queue,
275 value,
276 mem,
277 event.context("missing queue event")?,
278 )
279 })
280 .transpose()?;
281 Ok(())
282 }
283 }
284
285 /// Ops for running vhost-user over a stream (i.e. regular protocol).
286 pub(super) struct VhostUserRegularOps;
287
288 impl VhostUserRegularOps {
set_mem_table( contexts: &[VhostUserMemoryRegion], files: Vec<File>, ) -> VhostResult<(GuestMemory, Vec<MappingInfo>)>289 pub fn set_mem_table(
290 contexts: &[VhostUserMemoryRegion],
291 files: Vec<File>,
292 ) -> VhostResult<(GuestMemory, Vec<MappingInfo>)> {
293 if files.len() != contexts.len() {
294 return Err(VhostError::InvalidParam);
295 }
296
297 let mut regions = Vec::with_capacity(files.len());
298 for (region, file) in contexts.iter().zip(files.into_iter()) {
299 let region = MemoryRegion::new_from_shm(
300 region.memory_size,
301 GuestAddress(region.guest_phys_addr),
302 region.mmap_offset,
303 Arc::new(
304 SharedMemory::from_safe_descriptor(
305 SafeDescriptor::from(file),
306 region.memory_size,
307 )
308 .unwrap(),
309 ),
310 )
311 .map_err(|e| {
312 error!("failed to create a memory region: {}", e);
313 VhostError::InvalidOperation
314 })?;
315 regions.push(region);
316 }
317 let guest_mem = GuestMemory::from_regions(regions).map_err(|e| {
318 error!("failed to create guest memory: {}", e);
319 VhostError::InvalidOperation
320 })?;
321
322 let vmm_maps = contexts
323 .iter()
324 .map(|region| MappingInfo {
325 vmm_addr: region.user_addr,
326 guest_phys: region.guest_phys_addr,
327 size: region.memory_size,
328 })
329 .collect();
330 Ok((guest_mem, vmm_maps))
331 }
332
set_vring_kick(_index: u8, file: Option<File>) -> VhostResult<Event>333 pub fn set_vring_kick(_index: u8, file: Option<File>) -> VhostResult<Event> {
334 let file = file.ok_or(VhostError::InvalidParam)?;
335 // Remove O_NONBLOCK from kick_fd. Otherwise, uring_executor will fails when we read
336 // values via `next_val()` later.
337 // This is only required (and can only be done) on Unix platforms.
338 #[cfg(any(target_os = "android", target_os = "linux"))]
339 if let Err(e) = clear_fd_flags(file.as_raw_fd(), libc::O_NONBLOCK) {
340 error!("failed to remove O_NONBLOCK for kick fd: {}", e);
341 return Err(VhostError::InvalidParam);
342 }
343 Ok(Event::from(SafeDescriptor::from(file)))
344 }
345
set_vring_call( _index: u8, file: Option<File>, signal_config_change_fn: Box<dyn Fn() + Send + Sync>, ) -> VhostResult<Interrupt>346 pub fn set_vring_call(
347 _index: u8,
348 file: Option<File>,
349 signal_config_change_fn: Box<dyn Fn() + Send + Sync>,
350 ) -> VhostResult<Interrupt> {
351 let file = file.ok_or(VhostError::InvalidParam)?;
352 Ok(Interrupt::new_vhost_user(
353 Event::from(SafeDescriptor::from(file)),
354 signal_config_change_fn,
355 ))
356 }
357 }
358
/// An adapter that implements `vmm_vhost::Backend` for any type implementing `VhostUserDevice`.
pub struct DeviceRequestHandler<T: VhostUserDevice> {
    // One entry per queue reported by `VhostUserDevice::max_queue_num`.
    vrings: Vec<Vring>,
    // Set by `set_owner`, cleared by `reset_owner`; `set_features` requires it.
    owned: bool,
    // VMM virtual address -> guest physical address mappings from the last memory table;
    // `None` until `set_mem_table` is received.
    vmm_maps: Option<Vec<MappingInfo>>,
    // Guest memory built from the last memory table; `None` until `set_mem_table` is received.
    mem: Option<GuestMemory>,
    // The device implementation that requests are dispatched to.
    backend: T,
    // Backend-to-frontend request connection. Shared with the config-change closures given to
    // each vring's doorbell in `set_vring_call`.
    backend_req_connection: Arc<Mutex<VhostBackendReqConnectionState>>,
}
368
/// Serialized state of a `DeviceRequestHandler`, produced by its `snapshot` method and
/// consumed by its `restore` method.
#[derive(Serialize, Deserialize)]
pub struct DeviceRequestHandlerSnapshot {
    // One snapshot per vring, in the same order as `DeviceRequestHandler::vrings`.
    vrings: Vec<VringSnapshot>,
    // Opaque bytes returned by `VhostUserDevice::snapshot`.
    backend: Vec<u8>,
}
374
375 impl<T: VhostUserDevice> DeviceRequestHandler<T> {
376 /// Creates a vhost-user handler instance for `backend`.
new(backend: T) -> Self377 pub(crate) fn new(backend: T) -> Self {
378 let mut vrings = Vec::with_capacity(backend.max_queue_num());
379 for _ in 0..backend.max_queue_num() {
380 vrings.push(Vring::new(Queue::MAX_SIZE, backend.features()));
381 }
382
383 DeviceRequestHandler {
384 vrings,
385 owned: false,
386 vmm_maps: None,
387 mem: None,
388 backend,
389 backend_req_connection: Arc::new(Mutex::new(
390 VhostBackendReqConnectionState::NoConnection,
391 )),
392 }
393 }
394 }
395
396 impl<T: VhostUserDevice> AsRef<T> for DeviceRequestHandler<T> {
as_ref(&self) -> &T397 fn as_ref(&self) -> &T {
398 &self.backend
399 }
400 }
401
402 impl<T: VhostUserDevice> AsMut<T> for DeviceRequestHandler<T> {
as_mut(&mut self) -> &mut T403 fn as_mut(&mut self) -> &mut T {
404 &mut self.backend
405 }
406 }
407
408 impl<T: VhostUserDevice> vmm_vhost::Backend for DeviceRequestHandler<T> {
set_owner(&mut self) -> VhostResult<()>409 fn set_owner(&mut self) -> VhostResult<()> {
410 if self.owned {
411 return Err(VhostError::InvalidOperation);
412 }
413 self.owned = true;
414 Ok(())
415 }
416
reset_owner(&mut self) -> VhostResult<()>417 fn reset_owner(&mut self) -> VhostResult<()> {
418 self.owned = false;
419 self.backend.reset();
420 Ok(())
421 }
422
get_features(&mut self) -> VhostResult<u64>423 fn get_features(&mut self) -> VhostResult<u64> {
424 let features = self.backend.features();
425 Ok(features)
426 }
427
set_features(&mut self, features: u64) -> VhostResult<()>428 fn set_features(&mut self, features: u64) -> VhostResult<()> {
429 if !self.owned {
430 return Err(VhostError::InvalidOperation);
431 }
432
433 if (features & !(self.backend.features())) != 0 {
434 return Err(VhostError::InvalidParam);
435 }
436
437 if let Err(e) = self.backend.ack_features(features) {
438 error!("failed to acknowledge features 0x{:x}: {}", features, e);
439 return Err(VhostError::InvalidOperation);
440 }
441
442 // If VHOST_USER_F_PROTOCOL_FEATURES has not been negotiated, the ring is initialized in an
443 // enabled state.
444 // If VHOST_USER_F_PROTOCOL_FEATURES has been negotiated, the ring is initialized in a
445 // disabled state.
446 // Client must not pass data to/from the backend until ring is enabled by
447 // VHOST_USER_SET_VRING_ENABLE with parameter 1, or after it has been disabled by
448 // VHOST_USER_SET_VRING_ENABLE with parameter 0.
449 let acked_features = self.backend.acked_features();
450 let vring_enabled = acked_features & 1 << VHOST_USER_F_PROTOCOL_FEATURES != 0;
451 for v in &mut self.vrings {
452 v.enabled = vring_enabled;
453 }
454
455 Ok(())
456 }
457
get_protocol_features(&mut self) -> VhostResult<VhostUserProtocolFeatures>458 fn get_protocol_features(&mut self) -> VhostResult<VhostUserProtocolFeatures> {
459 Ok(self.backend.protocol_features())
460 }
461
set_protocol_features(&mut self, features: u64) -> VhostResult<()>462 fn set_protocol_features(&mut self, features: u64) -> VhostResult<()> {
463 if let Err(e) = self.backend.ack_protocol_features(features) {
464 error!("failed to set protocol features 0x{:x}: {}", features, e);
465 return Err(VhostError::InvalidOperation);
466 }
467 Ok(())
468 }
469
set_mem_table( &mut self, contexts: &[VhostUserMemoryRegion], files: Vec<File>, ) -> VhostResult<()>470 fn set_mem_table(
471 &mut self,
472 contexts: &[VhostUserMemoryRegion],
473 files: Vec<File>,
474 ) -> VhostResult<()> {
475 let (guest_mem, vmm_maps) = VhostUserRegularOps::set_mem_table(contexts, files)?;
476 self.mem = Some(guest_mem);
477 self.vmm_maps = Some(vmm_maps);
478 Ok(())
479 }
480
get_queue_num(&mut self) -> VhostResult<u64>481 fn get_queue_num(&mut self) -> VhostResult<u64> {
482 Ok(self.vrings.len() as u64)
483 }
484
set_vring_num(&mut self, index: u32, num: u32) -> VhostResult<()>485 fn set_vring_num(&mut self, index: u32, num: u32) -> VhostResult<()> {
486 if index as usize >= self.vrings.len() || num == 0 || num > Queue::MAX_SIZE.into() {
487 return Err(VhostError::InvalidParam);
488 }
489 self.vrings[index as usize].queue.set_size(num as u16);
490
491 Ok(())
492 }
493
set_vring_addr( &mut self, index: u32, _flags: VhostUserVringAddrFlags, descriptor: u64, used: u64, available: u64, _log: u64, ) -> VhostResult<()>494 fn set_vring_addr(
495 &mut self,
496 index: u32,
497 _flags: VhostUserVringAddrFlags,
498 descriptor: u64,
499 used: u64,
500 available: u64,
501 _log: u64,
502 ) -> VhostResult<()> {
503 if index as usize >= self.vrings.len() {
504 return Err(VhostError::InvalidParam);
505 }
506
507 let vmm_maps = self.vmm_maps.as_ref().ok_or(VhostError::InvalidParam)?;
508 let vring = &mut self.vrings[index as usize];
509 vring
510 .queue
511 .set_desc_table(vmm_va_to_gpa(vmm_maps, descriptor)?);
512 vring
513 .queue
514 .set_avail_ring(vmm_va_to_gpa(vmm_maps, available)?);
515 vring.queue.set_used_ring(vmm_va_to_gpa(vmm_maps, used)?);
516
517 Ok(())
518 }
519
set_vring_base(&mut self, index: u32, base: u32) -> VhostResult<()>520 fn set_vring_base(&mut self, index: u32, base: u32) -> VhostResult<()> {
521 if index as usize >= self.vrings.len() || base >= Queue::MAX_SIZE.into() {
522 return Err(VhostError::InvalidParam);
523 }
524
525 let vring = &mut self.vrings[index as usize];
526 vring.queue.set_next_avail(Wrapping(base as u16));
527 vring.queue.set_next_used(Wrapping(base as u16));
528
529 Ok(())
530 }
531
get_vring_base(&mut self, index: u32) -> VhostResult<VhostUserVringState>532 fn get_vring_base(&mut self, index: u32) -> VhostResult<VhostUserVringState> {
533 let vring = self
534 .vrings
535 .get_mut(index as usize)
536 .ok_or(VhostError::InvalidParam)?;
537
538 // Quotation from vhost-user spec:
539 // "The back-end must [...] stop ring upon receiving VHOST_USER_GET_VRING_BASE."
540 // We only call `queue.set_ready()` when starting the queue, so if the queue is ready, that
541 // means it is started and should be stopped.
542 if vring.queue.ready() {
543 if let Err(e) = self.backend.stop_queue(index as usize) {
544 error!("Failed to stop queue in get_vring_base: {:#}", e);
545 }
546
547 vring.reset();
548 }
549
550 Ok(VhostUserVringState::new(
551 index,
552 vring.queue.next_avail().0 as u32,
553 ))
554 }
555
set_vring_kick(&mut self, index: u8, file: Option<File>) -> VhostResult<()>556 fn set_vring_kick(&mut self, index: u8, file: Option<File>) -> VhostResult<()> {
557 if index as usize >= self.vrings.len() {
558 return Err(VhostError::InvalidParam);
559 }
560
561 let vring = &mut self.vrings[index as usize];
562 if vring.queue.ready() {
563 error!("kick fd cannot replaced after queue is started");
564 return Err(VhostError::InvalidOperation);
565 }
566
567 let kick_evt = VhostUserRegularOps::set_vring_kick(index, file)?;
568
569 // Enable any virtqueue features that were negotiated (like VIRTIO_RING_F_EVENT_IDX).
570 vring.queue.ack_features(self.backend.acked_features());
571 vring.queue.set_ready(true);
572
573 let mem = self
574 .mem
575 .as_ref()
576 .cloned()
577 .ok_or(VhostError::InvalidOperation)?;
578
579 let queue = match vring.queue.activate(&mem, kick_evt) {
580 Ok(queue) => queue,
581 Err(e) => {
582 error!("failed to activate vring: {:#}", e);
583 return Err(VhostError::BackendInternalError);
584 }
585 };
586
587 let doorbell = vring.doorbell.clone().ok_or(VhostError::InvalidOperation)?;
588
589 if let Err(e) = self
590 .backend
591 .start_queue(index as usize, queue, mem, doorbell)
592 {
593 error!("Failed to start queue {}: {}", index, e);
594 return Err(VhostError::BackendInternalError);
595 }
596
597 Ok(())
598 }
599
set_vring_call(&mut self, index: u8, file: Option<File>) -> VhostResult<()>600 fn set_vring_call(&mut self, index: u8, file: Option<File>) -> VhostResult<()> {
601 if index as usize >= self.vrings.len() {
602 return Err(VhostError::InvalidParam);
603 }
604
605 let backend_req_conn = self.backend_req_connection.clone();
606 let signal_config_change_fn = Box::new(move || match &*backend_req_conn.lock() {
607 VhostBackendReqConnectionState::Connected(frontend) => {
608 if let Err(e) = frontend.send_config_changed() {
609 error!("Failed to notify config change: {:#}", e);
610 }
611 }
612 VhostBackendReqConnectionState::NoConnection => {
613 error!("No Backend request connection found");
614 }
615 });
616
617 let doorbell = VhostUserRegularOps::set_vring_call(index, file, signal_config_change_fn)?;
618 self.vrings[index as usize].doorbell = Some(doorbell);
619 Ok(())
620 }
621
set_vring_err(&mut self, _index: u8, _fd: Option<File>) -> VhostResult<()>622 fn set_vring_err(&mut self, _index: u8, _fd: Option<File>) -> VhostResult<()> {
623 // TODO
624 Ok(())
625 }
626
set_vring_enable(&mut self, index: u32, enable: bool) -> VhostResult<()>627 fn set_vring_enable(&mut self, index: u32, enable: bool) -> VhostResult<()> {
628 if index as usize >= self.vrings.len() {
629 return Err(VhostError::InvalidParam);
630 }
631
632 // This request should be handled only when VHOST_USER_F_PROTOCOL_FEATURES
633 // has been negotiated.
634 if self.backend.acked_features() & 1 << VHOST_USER_F_PROTOCOL_FEATURES == 0 {
635 return Err(VhostError::InvalidOperation);
636 }
637
638 // Backend must not pass data to/from the ring until ring is enabled by
639 // VHOST_USER_SET_VRING_ENABLE with parameter 1, or after it has been disabled by
640 // VHOST_USER_SET_VRING_ENABLE with parameter 0.
641 self.vrings[index as usize].enabled = enable;
642
643 Ok(())
644 }
645
get_config( &mut self, offset: u32, size: u32, _flags: VhostUserConfigFlags, ) -> VhostResult<Vec<u8>>646 fn get_config(
647 &mut self,
648 offset: u32,
649 size: u32,
650 _flags: VhostUserConfigFlags,
651 ) -> VhostResult<Vec<u8>> {
652 let mut data = vec![0; size as usize];
653 self.backend.read_config(u64::from(offset), &mut data);
654 Ok(data)
655 }
656
set_config( &mut self, offset: u32, buf: &[u8], _flags: VhostUserConfigFlags, ) -> VhostResult<()>657 fn set_config(
658 &mut self,
659 offset: u32,
660 buf: &[u8],
661 _flags: VhostUserConfigFlags,
662 ) -> VhostResult<()> {
663 self.backend.write_config(u64::from(offset), buf);
664 Ok(())
665 }
666
set_backend_req_fd(&mut self, ep: Connection<BackendReq>)667 fn set_backend_req_fd(&mut self, ep: Connection<BackendReq>) {
668 let conn = Arc::new(VhostBackendReqConnection::new(
669 FrontendClient::new(ep),
670 self.backend.get_shared_memory_region().map(|r| r.id),
671 ));
672
673 {
674 let mut backend_req_conn = self.backend_req_connection.lock();
675 if let VhostBackendReqConnectionState::Connected(_) = &*backend_req_conn {
676 warn!("Backend Request Connection already established. Overwriting");
677 }
678 *backend_req_conn = VhostBackendReqConnectionState::Connected(conn.clone());
679 }
680
681 self.backend.set_backend_req_connection(conn);
682 }
683
get_inflight_fd( &mut self, _inflight: &VhostUserInflight, ) -> VhostResult<(VhostUserInflight, File)>684 fn get_inflight_fd(
685 &mut self,
686 _inflight: &VhostUserInflight,
687 ) -> VhostResult<(VhostUserInflight, File)> {
688 unimplemented!("get_inflight_fd");
689 }
690
set_inflight_fd(&mut self, _inflight: &VhostUserInflight, _file: File) -> VhostResult<()>691 fn set_inflight_fd(&mut self, _inflight: &VhostUserInflight, _file: File) -> VhostResult<()> {
692 unimplemented!("set_inflight_fd");
693 }
694
get_max_mem_slots(&mut self) -> VhostResult<u64>695 fn get_max_mem_slots(&mut self) -> VhostResult<u64> {
696 //TODO
697 Ok(0)
698 }
699
add_mem_region( &mut self, _region: &VhostUserSingleMemoryRegion, _fd: File, ) -> VhostResult<()>700 fn add_mem_region(
701 &mut self,
702 _region: &VhostUserSingleMemoryRegion,
703 _fd: File,
704 ) -> VhostResult<()> {
705 //TODO
706 Ok(())
707 }
708
remove_mem_region(&mut self, _region: &VhostUserSingleMemoryRegion) -> VhostResult<()>709 fn remove_mem_region(&mut self, _region: &VhostUserSingleMemoryRegion) -> VhostResult<()> {
710 //TODO
711 Ok(())
712 }
713
get_shared_memory_regions(&mut self) -> VhostResult<Vec<VhostSharedMemoryRegion>>714 fn get_shared_memory_regions(&mut self) -> VhostResult<Vec<VhostSharedMemoryRegion>> {
715 Ok(if let Some(r) = self.backend.get_shared_memory_region() {
716 vec![VhostSharedMemoryRegion::new(r.id, r.length)]
717 } else {
718 Vec::new()
719 })
720 }
721
sleep(&mut self) -> VhostResult<()>722 fn sleep(&mut self) -> VhostResult<()> {
723 for (index, vring) in self
724 .vrings
725 .iter_mut()
726 .enumerate()
727 .filter(|(_index, vring)| vring.queue.ready())
728 {
729 match self.backend.stop_queue(index) {
730 Ok(queue) => vring.paused_queue = Some(queue),
731 Err(e) => {
732 error!("failed to stop queue index {}: {:#}", index, e);
733 return Err(VhostError::StopQueueError(e));
734 }
735 }
736 }
737 self.backend
738 .stop_non_queue_workers()
739 .map_err(VhostError::SleepError)
740 }
741
wake(&mut self) -> VhostResult<()>742 fn wake(&mut self) -> VhostResult<()> {
743 for (index, vring) in self.vrings.iter_mut().enumerate() {
744 if let Some(queue) = vring.paused_queue.take() {
745 let mem = self.mem.clone().ok_or(VhostError::BackendInternalError)?;
746 let doorbell = vring.doorbell.clone().expect("Failed to clone doorbell");
747
748 if let Err(e) = self.backend.start_queue(index, queue, mem, doorbell) {
749 error!("Failed to start queue {}: {}", index, e);
750 return Err(VhostError::BackendInternalError);
751 }
752 }
753 }
754 Ok(())
755 }
756
snapshot(&mut self) -> VhostResult<Vec<u8>>757 fn snapshot(&mut self) -> VhostResult<Vec<u8>> {
758 match serde_json::to_vec(&DeviceRequestHandlerSnapshot {
759 vrings: self
760 .vrings
761 .iter()
762 .map(|vring| vring.snapshot())
763 .collect::<anyhow::Result<Vec<VringSnapshot>>>()
764 .map_err(VhostError::SnapshotError)?,
765 backend: self.backend.snapshot().map_err(VhostError::SnapshotError)?,
766 }) {
767 Ok(serialized_json) => Ok(serialized_json),
768 Err(e) => {
769 error!("Failed to serialize DeviceRequestHandlerSnapshot: {}", e);
770 Err(VhostError::SerializationFailed)
771 }
772 }
773 }
774
restore(&mut self, data_bytes: &[u8], queue_evts: Vec<File>) -> VhostResult<()>775 fn restore(&mut self, data_bytes: &[u8], queue_evts: Vec<File>) -> VhostResult<()> {
776 let device_request_handler_snapshot: DeviceRequestHandlerSnapshot =
777 serde_json::from_slice(data_bytes).map_err(|e| {
778 error!("Failed to deserialize DeviceRequestHandlerSnapshot: {}", e);
779 VhostError::DeserializationFailed
780 })?;
781
782 let mem = self.mem.as_ref().ok_or(VhostError::InvalidOperation)?;
783
784 let snapshotted_vrings = device_request_handler_snapshot.vrings;
785 assert_eq!(snapshotted_vrings.len(), self.vrings.len());
786
787 let mut queue_evts_iter = if queue_evts.is_empty() {
788 None
789 } else {
790 Some(queue_evts.into_iter())
791 };
792
793 for (index, (vring, snapshotted_vring)) in self
794 .vrings
795 .iter_mut()
796 .zip(snapshotted_vrings.into_iter())
797 .enumerate()
798 {
799 let queue_evt = if let Some(queue_evts_iter) = &mut queue_evts_iter {
800 // TODO(b/288596005): It is assumed that the index of `queue_evts` should map to the
801 // index of `self.vrings`. However, this assumption may break in the future, so a
802 // Map of indexes to queue_evt should be used to support sparse activated queues.
803 let queue_evt_file = queue_evts_iter
804 .next()
805 .ok_or(VhostError::VringIndexNotFound(index))?;
806 Some(VhostUserRegularOps::set_vring_kick(
807 index as u8,
808 Some(queue_evt_file),
809 )?)
810 } else {
811 None
812 };
813
814 vring
815 .restore(snapshotted_vring, mem, queue_evt)
816 .map_err(VhostError::RestoreError)?;
817 }
818
819 self.backend
820 .restore(device_request_handler_snapshot.backend)
821 .map_err(VhostError::RestoreError)?;
822
823 Ok(())
824 }
825 }
826
/// Indicates the state of the backend request connection.
///
/// Stored behind `DeviceRequestHandler::backend_req_connection`. It switches to `Connected` when
/// the frontend provides a backend request fd (`set_backend_req_fd`).
pub enum VhostBackendReqConnectionState {
    /// A backend request connection (`VhostBackendReqConnection`) is established
    Connected(Arc<VhostBackendReqConnection>),
    /// No backend request connection has been established yet
    NoConnection,
}
834
/// Keeps track of the vhost-user backend request connection.
///
/// Wraps the `FrontendClient` used for backend-initiated messages (config-change notifications
/// and shared memory mapping requests).
pub struct VhostBackendReqConnection {
    // Shared with the `VhostShmemMapper` created by `take_shmem_mapper`.
    conn: Arc<Mutex<FrontendClient>>,
    // `None` if the device has no shared memory region, or after `take_shmem_mapper` has
    // taken it.
    shmem_info: Mutex<Option<ShmemInfo>>,
}
840
/// Bookkeeping for the device's shared memory region.
#[derive(Clone)]
struct ShmemInfo {
    // Id of the device's shared memory region, sent in every map message.
    shmid: u8,
    // Mappings added through `VhostShmemMapper::add_mapping`, keyed by offset into the region.
    mapped_regions: BTreeMap<u64 /* offset */, u64 /* size */>,
}
846
847 impl VhostBackendReqConnection {
new(conn: FrontendClient, shmid: Option<u8>) -> Self848 pub fn new(conn: FrontendClient, shmid: Option<u8>) -> Self {
849 let shmem_info = Mutex::new(shmid.map(|shmid| ShmemInfo {
850 shmid,
851 mapped_regions: BTreeMap::new(),
852 }));
853 Self {
854 conn: Arc::new(Mutex::new(conn)),
855 shmem_info,
856 }
857 }
858
859 /// Send `VHOST_USER_CONFIG_CHANGE_MSG` to the frontend
send_config_changed(&self) -> anyhow::Result<()>860 pub fn send_config_changed(&self) -> anyhow::Result<()> {
861 self.conn
862 .lock()
863 .handle_config_change()
864 .context("Could not send config change message")?;
865 Ok(())
866 }
867
868 /// Create a SharedMemoryMapper trait object from the ShmemInfo.
take_shmem_mapper(&self) -> anyhow::Result<Box<dyn SharedMemoryMapper>>869 pub fn take_shmem_mapper(&self) -> anyhow::Result<Box<dyn SharedMemoryMapper>> {
870 let shmem_info = self
871 .shmem_info
872 .lock()
873 .take()
874 .context("could not take shared memory mapper information")?;
875
876 Ok(Box::new(VhostShmemMapper {
877 conn: self.conn.clone(),
878 shmem_info,
879 }))
880 }
881 }
882
/// `SharedMemoryMapper` that sends mapping requests for the device's shared memory region to the
/// frontend over the backend request connection.
struct VhostShmemMapper {
    // Connection shared with the owning `VhostBackendReqConnection`.
    conn: Arc<Mutex<FrontendClient>>,
    // Region id and the set of currently mapped (offset, size) ranges.
    shmem_info: ShmemInfo,
}
887
888 impl SharedMemoryMapper for VhostShmemMapper {
add_mapping( &mut self, source: VmMemorySource, offset: u64, prot: Protection, _cache: MemCacheType, ) -> anyhow::Result<()>889 fn add_mapping(
890 &mut self,
891 source: VmMemorySource,
892 offset: u64,
893 prot: Protection,
894 _cache: MemCacheType,
895 ) -> anyhow::Result<()> {
896 let size = match source {
897 VmMemorySource::Vulkan {
898 descriptor,
899 handle_type,
900 memory_idx,
901 device_uuid,
902 driver_uuid,
903 size,
904 } => {
905 let msg = VhostUserGpuMapMsg::new(
906 self.shmem_info.shmid,
907 offset,
908 size,
909 memory_idx,
910 handle_type,
911 device_uuid,
912 driver_uuid,
913 );
914 self.conn
915 .lock()
916 .gpu_map(&msg, &descriptor)
917 .context("failed to map memory")?;
918 size
919 }
920 VmMemorySource::ExternalMapping { ptr, size } => {
921 let msg = VhostUserExternalMapMsg::new(self.shmem_info.shmid, offset, size, ptr);
922 self.conn
923 .lock()
924 .external_map(&msg)
925 .context("failed to map memory")?;
926 size
927 }
928 source => {
929 // The last two sources use the same VhostUserShmemMapMsg, continue matching here
930 // on the aliased `source` above.
931 let (descriptor, fd_offset, size) = match source {
932 VmMemorySource::Descriptor {
933 descriptor,
934 offset,
935 size,
936 } => (descriptor, offset, size),
937 VmMemorySource::SharedMemory(shmem) => {
938 let size = shmem.size();
939 let descriptor =
940 // SAFETY:
941 // Safe because we own shmem.
942 unsafe {
943 SafeDescriptor::from_raw_descriptor(shmem.into_raw_descriptor())
944 };
945 (descriptor, 0, size)
946 }
947 _ => bail!("unsupported source"),
948 };
949 let flags = VhostUserShmemMapMsgFlags::from(prot);
950 let msg = VhostUserShmemMapMsg::new(
951 self.shmem_info.shmid,
952 offset,
953 fd_offset,
954 size,
955 flags,
956 );
957 self.conn
958 .lock()
959 .shmem_map(&msg, &descriptor)
960 .context("failed to map memory")?;
961 size
962 }
963 };
964
965 self.shmem_info.mapped_regions.insert(offset, size);
966 Ok(())
967 }
968
remove_mapping(&mut self, offset: u64) -> anyhow::Result<()>969 fn remove_mapping(&mut self, offset: u64) -> anyhow::Result<()> {
970 let size = self
971 .shmem_info
972 .mapped_regions
973 .remove(&offset)
974 .context("unknown offset")?;
975 let msg = VhostUserShmemUnmapMsg::new(self.shmem_info.shmid, offset, size);
976 self.conn
977 .lock()
978 .shmem_unmap(&msg)
979 .context("failed to map memory")
980 .map(|_| ())
981 }
982 }
983
/// Bundles a queue worker's task handle with the queue it services.
pub(crate) struct WorkerState<T, U> {
    // Handle to the task servicing the queue.
    // NOTE(review): `U` is presumably the task's output type — confirm against `TaskHandle`.
    pub(crate) queue_task: TaskHandle<U>,
    // The queue (or queue-related state) kept alongside the task; its exact type is chosen by
    // the user of this struct.
    pub(crate) queue: T,
}
988
/// Errors for device operations.
#[derive(Debug, ThisError)]
pub enum Error {
    /// No active worker was found for the queue being stopped.
    #[error("worker not found when stopping queue")]
    WorkerNotFound,
}
995
#[cfg(test)]
mod tests {
    use std::sync::mpsc::channel;
    use std::sync::Barrier;

    use anyhow::anyhow;
    use anyhow::bail;
    use base::Event;
    use vmm_vhost::BackendServer;
    use vmm_vhost::FrontendReq;
    use zerocopy::AsBytes;
    use zerocopy::FromBytes;
    use zerocopy::FromZeroes;

    use super::sys::test_helpers;
    use super::*;
    use crate::virtio::vhost_user_frontend::VhostUserFrontend;
    use crate::virtio::DeviceType;
    use crate::virtio::VirtioDevice;

    /// Config space layout served by `FakeBackend::read_config`.
    #[derive(Clone, Copy, Debug, PartialEq, Eq, AsBytes, FromZeroes, FromBytes)]
    #[repr(C, packed(4))]
    struct FakeConfig {
        x: u32,
        y: u64,
    }

    /// Config contents the frontend is expected to read back from `FakeBackend`.
    const FAKE_CONFIG_DATA: FakeConfig = FakeConfig { x: 1, y: 2 };

    /// Minimal `VhostUserDevice` used to drive a `DeviceRequestHandler` through the vhost-user
    /// request flow. It only records negotiated features and started queues.
    pub(super) struct FakeBackend {
        avail_features: u64,
        acked_features: u64,
        acked_protocol_features: VhostUserProtocolFeatures,
        // One slot per queue index; `Some` while the queue is started.
        active_queues: Vec<Option<Queue>>,
        // When true, `BACKEND_REQ` is advertised by `protocol_features`.
        allow_backend_req: bool,
        // Set by `set_backend_req_connection`.
        backend_conn: Option<Arc<VhostBackendReqConnection>>,
    }

    impl FakeBackend {
        const MAX_QUEUE_NUM: usize = 16;

        pub(super) fn new() -> Self {
            let mut active_queues = Vec::new();
            active_queues.resize_with(Self::MAX_QUEUE_NUM, Default::default);
            Self {
                avail_features: 1 << VHOST_USER_F_PROTOCOL_FEATURES,
                acked_features: 0,
                acked_protocol_features: VhostUserProtocolFeatures::empty(),
                active_queues,
                allow_backend_req: false,
                backend_conn: None,
            }
        }
    }

    impl VhostUserDevice for FakeBackend {
        fn max_queue_num(&self) -> usize {
            Self::MAX_QUEUE_NUM
        }

        fn features(&self) -> u64 {
            self.avail_features
        }

        fn ack_features(&mut self, value: u64) -> anyhow::Result<()> {
            let unrequested_features = value & !self.avail_features;
            if unrequested_features != 0 {
                // These are virtio feature bits, not protocol features.
                bail!("invalid features are given: 0x{:x}", unrequested_features);
            }
            self.acked_features |= value;
            Ok(())
        }

        fn acked_features(&self) -> u64 {
            self.acked_features
        }

        fn protocol_features(&self) -> VhostUserProtocolFeatures {
            let mut features = VhostUserProtocolFeatures::CONFIG;
            if self.allow_backend_req {
                features |= VhostUserProtocolFeatures::BACKEND_REQ;
            }
            features
        }

        fn ack_protocol_features(&mut self, features: u64) -> anyhow::Result<()> {
            // Build the error lazily; it is only needed when unknown bits are present.
            let features = VhostUserProtocolFeatures::from_bits(features).ok_or_else(|| {
                anyhow!("invalid protocol features are given: 0x{:x}", features)
            })?;
            let supported = self.protocol_features();
            self.acked_protocol_features = features & supported;
            Ok(())
        }

        fn acked_protocol_features(&self) -> u64 {
            self.acked_protocol_features.bits()
        }

        fn read_config(&self, offset: u64, dst: &mut [u8]) {
            // Copy exactly `dst.len()` bytes so reads shorter than the rest of the config space
            // do not panic on a slice length mismatch.
            let start = offset as usize;
            dst.copy_from_slice(&FAKE_CONFIG_DATA.as_bytes()[start..start + dst.len()]);
        }

        fn reset(&mut self) {}

        fn start_queue(
            &mut self,
            idx: usize,
            queue: Queue,
            _mem: GuestMemory,
            _doorbell: Interrupt,
        ) -> anyhow::Result<()> {
            self.active_queues[idx] = Some(queue);
            Ok(())
        }

        fn stop_queue(&mut self, idx: usize) -> anyhow::Result<Queue> {
            Ok(self.active_queues[idx]
                .take()
                .ok_or(Error::WorkerNotFound)?)
        }

        fn set_backend_req_connection(&mut self, conn: Arc<VhostBackendReqConnection>) {
            self.backend_conn = Some(conn);
        }
    }

    #[test]
    fn test_vhost_user_activate() {
        test_vhost_user_activate_parameterized(false);
    }

    #[test]
    #[cfg(not(windows))] // Windows requires more complex connection setup.
    fn test_vhost_user_activate_with_backend_req() {
        test_vhost_user_activate_parameterized(true);
    }

    /// Runs a `VhostUserFrontend` on a helper thread against a `FakeBackend` served on this
    /// thread. Checks that the device side receives the expected requests, in order, for
    /// `new`, `read_config`, `activate`, `reset`, `activate`, `virtio_sleep` and `virtio_wake`.
    ///
    /// With `allow_backend_req`, the backend request connection is also exercised after the
    /// reset/re-activate cycle and after sleep/wake.
    fn test_vhost_user_activate_parameterized(allow_backend_req: bool) {
        const QUEUES_NUM: usize = 2;

        let (dev, vmm) = test_helpers::setup();

        let vmm_bar = Arc::new(Barrier::new(2));
        let dev_bar = vmm_bar.clone();

        let (ready_tx, ready_rx) = channel();
        let (shutdown_tx, shutdown_rx) = channel();

        std::thread::spawn(move || {
            // VMM side
            ready_rx.recv().unwrap(); // Ensure the device is ready.

            let connection = test_helpers::connect(vmm);

            let mut vmm_device =
                VhostUserFrontend::new(DeviceType::Console, 0, connection, None, None).unwrap();

            println!("read_config");
            let mut buf = vec![0; std::mem::size_of::<FakeConfig>()];
            vmm_device.read_config(0, &mut buf);
            // Check if the obtained config data is correct.
            let config = FakeConfig::read_from(buf.as_bytes()).unwrap();
            assert_eq!(config, FAKE_CONFIG_DATA);

            let activate = |vmm_device: &mut VhostUserFrontend| {
                let mem = GuestMemory::new(&[(GuestAddress(0x0), 0x10000)]).unwrap();
                let interrupt = Interrupt::new_for_test_with_msix();

                let mut queues = BTreeMap::new();
                for idx in 0..QUEUES_NUM {
                    let mut queue = QueueConfig::new(0x10, 0);
                    queue.set_ready(true);
                    let queue = queue
                        .activate(&mem, Event::new().unwrap())
                        .expect("QueueConfig::activate");
                    queues.insert(idx, queue);
                }

                println!("activate");
                vmm_device
                    .activate(mem.clone(), interrupt.clone(), queues)
                    .unwrap();
            };

            activate(&mut vmm_device);

            println!("reset");
            if let Err(e) = vmm_device.reset() {
                panic!("reset failed: {:#}", e);
            }

            activate(&mut vmm_device);

            println!("virtio_sleep");
            vmm_device.virtio_sleep().unwrap();

            println!("virtio_wake");
            vmm_device.virtio_wake(None).unwrap();

            println!("wait for shutdown signal");
            shutdown_rx.recv().unwrap();

            // The VMM side is supposed to stop before the device side.
            println!("drop");
            drop(vmm_device);

            vmm_bar.wait();
        });

        // Device side
        let mut handler = DeviceRequestHandler::new(FakeBackend::new());
        handler.as_mut().allow_backend_req = allow_backend_req;

        // Notify listener is ready.
        ready_tx.send(()).unwrap();

        let mut req_handler = test_helpers::listen(dev, handler);

        // VhostUserFrontend::new()
        handle_request(&mut req_handler, FrontendReq::SET_OWNER).unwrap();
        handle_request(&mut req_handler, FrontendReq::GET_FEATURES).unwrap();
        handle_request(&mut req_handler, FrontendReq::SET_FEATURES).unwrap();
        handle_request(&mut req_handler, FrontendReq::GET_PROTOCOL_FEATURES).unwrap();
        handle_request(&mut req_handler, FrontendReq::SET_PROTOCOL_FEATURES).unwrap();
        if allow_backend_req {
            handle_request(&mut req_handler, FrontendReq::SET_BACKEND_REQ_FD).unwrap();
        }

        // VhostUserFrontend::read_config()
        handle_request(&mut req_handler, FrontendReq::GET_CONFIG).unwrap();

        // VhostUserFrontend::activate()
        handle_activate_requests(&mut req_handler, QUEUES_NUM);

        // VhostUserFrontend::reset()
        handle_reset_requests(&mut req_handler, QUEUES_NUM);

        // VhostUserFrontend::activate()
        handle_activate_requests(&mut req_handler, QUEUES_NUM);

        if allow_backend_req {
            // Make sure the connection still works even after reset/reactivate.
            assert_backend_conn_works(req_handler.as_ref().as_ref().backend_conn.as_ref());
        }

        // VhostUserFrontend::virtio_sleep()
        handle_request(&mut req_handler, FrontendReq::SLEEP).unwrap();

        // VhostUserFrontend::virtio_wake()
        handle_request(&mut req_handler, FrontendReq::WAKE).unwrap();

        if allow_backend_req {
            // Make sure the connection still works even after sleep/wake.
            assert_backend_conn_works(req_handler.as_ref().as_ref().backend_conn.as_ref());
        }

        // Ask the client to shut down, then wait for it to finish.
        shutdown_tx.send(()).unwrap();
        dev_bar.wait();

        // Verify recv_header fails with `ClientExit` after the client has disconnected.
        match req_handler.recv_header() {
            Err(VhostError::ClientExit) => (),
            r => panic!("expected Err(ClientExit) but got {:?}", r),
        }
    }

    /// Handles, in order, the requests `VhostUserFrontend::activate()` sends for `num_queues`
    /// queues. Panics if any request is unexpected or fails.
    fn handle_activate_requests<S: vmm_vhost::Backend>(
        handler: &mut BackendServer<S>,
        num_queues: usize,
    ) {
        handle_request(handler, FrontendReq::SET_MEM_TABLE).unwrap();
        for _ in 0..num_queues {
            handle_request(handler, FrontendReq::SET_VRING_NUM).unwrap();
            handle_request(handler, FrontendReq::SET_VRING_ADDR).unwrap();
            handle_request(handler, FrontendReq::SET_VRING_BASE).unwrap();
            handle_request(handler, FrontendReq::SET_VRING_CALL).unwrap();
            handle_request(handler, FrontendReq::SET_VRING_KICK).unwrap();
            handle_request(handler, FrontendReq::SET_VRING_ENABLE).unwrap();
        }
    }

    /// Handles, in order, the requests `VhostUserFrontend::reset()` sends for `num_queues`
    /// queues. Panics if any request is unexpected or fails.
    fn handle_reset_requests<S: vmm_vhost::Backend>(
        handler: &mut BackendServer<S>,
        num_queues: usize,
    ) {
        for _ in 0..num_queues {
            handle_request(handler, FrontendReq::SET_VRING_ENABLE).unwrap();
            handle_request(handler, FrontendReq::GET_VRING_BASE).unwrap();
        }
    }

    /// Sends a config-change message over the backend request connection. Panics if the
    /// connection was never set or the send fails.
    fn assert_backend_conn_works(backend_conn: Option<&Arc<VhostBackendReqConnection>>) {
        backend_conn
            .expect("backend_conn missing")
            .send_config_changed()
            .expect("send_config_changed failed");
    }

    /// Receives the next frontend request, asserts that it is `expected_message_type`, and
    /// processes it with `handler`.
    fn handle_request<S: vmm_vhost::Backend>(
        handler: &mut BackendServer<S>,
        expected_message_type: FrontendReq,
    ) -> Result<(), VhostError> {
        let (hdr, files) = handler.recv_header()?;
        assert_eq!(hdr.get_code(), Ok(expected_message_type));
        handler.process_message(hdr, files)
    }
}
1312