1 // Copyright 2021 The ChromiumOS Authors
2 // Use of this source code is governed by a BSD-style license that can be
3 // found in the LICENSE file.
4
5 //! Library for implementing vhost-user device executables.
6 //!
7 //! This crate provides
8 //! * `VhostUserBackend` trait, which is a collection of methods to handle vhost-user requests, and
9 //! * `DeviceRequestHandler` struct, which makes a connection to a VMM and starts an event loop.
10 //!
11 //! They are expected to be used as follows:
12 //!
13 //! 1. Define a struct and implement `VhostUserBackend` for it.
14 //! 2. Create a `DeviceRequestHandler` with the backend struct.
15 //! 3. Drive the `DeviceRequestHandler::run` async fn with an executor.
16 //!
17 //! ```ignore
18 //! struct MyBackend {
19 //! /* fields */
20 //! }
21 //!
22 //! impl VhostUserBackend for MyBackend {
23 //! /* implement methods */
24 //! }
25 //!
26 //! fn main() -> Result<(), Box<dyn Error>> {
27 //! let backend = MyBackend { /* initialize fields */ };
28 //! let handler = DeviceRequestHandler::new(backend);
//! let socket = std::path::Path::new("/path/to/socket");
30 //! let ex = cros_async::Executor::new()?;
31 //!
32 //! if let Err(e) = ex.run_until(handler.run(socket, &ex)) {
33 //! eprintln!("error happened: {}", e);
34 //! }
35 //! Ok(())
36 //! }
37 //! ```
38 //!
39 // Implementation note:
40 // This code lets us take advantage of the vmm_vhost low level implementation of the vhost user
41 // protocol. DeviceRequestHandler implements the VhostUserSlaveReqHandlerMut trait from vmm_vhost,
42 // and includes some common code for setting up guest memory and managing partially configured
43 // vrings. DeviceRequestHandler::run watches the vhost-user socket and then calls handle_request()
44 // when it becomes readable. handle_request() reads and parses the message and then calls one of the
45 // VhostUserSlaveReqHandlerMut trait methods. These dispatch back to the supplied VhostUserBackend
46 // implementation (this is what our devices implement).
47
48 pub(super) mod sys;
49
50 use std::collections::BTreeMap;
51 use std::convert::From;
52 use std::convert::TryFrom;
53 use std::fs::File;
54 use std::num::Wrapping;
55 #[cfg(unix)]
56 use std::os::unix::io::AsRawFd;
57 use std::sync::Arc;
58
59 use anyhow::bail;
60 use anyhow::Context;
61 #[cfg(unix)]
62 use base::clear_fd_flags;
63 use base::error;
64 use base::Event;
65 use base::FromRawDescriptor;
66 use base::IntoRawDescriptor;
67 use base::Protection;
68 use base::SafeDescriptor;
69 use base::SharedMemory;
70 use sys::Doorbell;
71 use vm_control::VmMemorySource;
72 use vm_memory::GuestAddress;
73 use vm_memory::GuestMemory;
74 use vm_memory::MemoryRegion;
75 use vmm_vhost::connection::Endpoint;
76 use vmm_vhost::message::SlaveReq;
77 use vmm_vhost::message::VhostSharedMemoryRegion;
78 use vmm_vhost::message::VhostUserConfigFlags;
79 use vmm_vhost::message::VhostUserGpuMapMsg;
80 use vmm_vhost::message::VhostUserInflight;
81 use vmm_vhost::message::VhostUserMemoryRegion;
82 use vmm_vhost::message::VhostUserProtocolFeatures;
83 use vmm_vhost::message::VhostUserShmemMapMsg;
84 use vmm_vhost::message::VhostUserShmemMapMsgFlags;
85 use vmm_vhost::message::VhostUserShmemUnmapMsg;
86 use vmm_vhost::message::VhostUserSingleMemoryRegion;
87 use vmm_vhost::message::VhostUserVirtioFeatures;
88 use vmm_vhost::message::VhostUserVringAddrFlags;
89 use vmm_vhost::message::VhostUserVringState;
90 use vmm_vhost::Error as VhostError;
91 use vmm_vhost::Protocol;
92 use vmm_vhost::Result as VhostResult;
93 use vmm_vhost::Slave;
94 use vmm_vhost::VhostUserMasterReqHandler;
95 use vmm_vhost::VhostUserSlaveReqHandlerMut;
96
97 use crate::virtio::Queue;
98 use crate::virtio::SharedMemoryMapper;
99 use crate::virtio::SharedMemoryRegion;
100 use crate::virtio::SignalableInterrupt;
101
/// Largest valid number of entries in a virtqueue.
///
/// 32768 (2^15) is the maximum queue size permitted by the virtio specification.
const MAX_VRING_LEN: u16 = 32768;
104
/// An event to deliver an interrupt to the guest.
///
/// Unlike `devices::Interrupt`, this doesn't support interrupt status and signal resampling.
// TODO(b/187487351): To avoid sending unnecessary events, we might want to support interrupt
// status. For this purpose, we need a mechanism to share interrupt status between the vmm and the
// device process.
#[derive(Clone)]
// The inner `Event` is wrapped in an `Arc` so that clones signal the same underlying event.
pub struct CallEvent(Arc<Event>);
113
114 impl CallEvent {
115 #[cfg_attr(windows, allow(dead_code))]
into_inner(self) -> Event116 pub fn into_inner(self) -> Event {
117 Arc::try_unwrap(self.0).unwrap()
118 }
119 }
120
impl SignalableInterrupt for CallEvent {
    // Signals the frontend's call event. The vector and status mask are ignored
    // because this event carries no interrupt status (see struct docs).
    fn signal(&self, _vector: u16, _interrupt_status_mask: u32) {
        self.0.signal().unwrap();
    }

    fn signal_config_changed(&self) {} // TODO(dgreid)

    // No resample event is available for call events.
    fn get_resample_evt(&self) -> Option<&Event> {
        None
    }

    // Resampling is a no-op since there is no resample event.
    fn do_interrupt_resample(&self) {}
}
134
impl From<File> for CallEvent {
    // Converts a file (e.g. the call fd received from the frontend) into a `CallEvent`.
    fn from(file: File) -> Self {
        // SAFETY: we own `file`, and `into_raw_descriptor()` transfers ownership
        // of its descriptor to the newly created `Event`.
        CallEvent(Arc::new(unsafe {
            Event::from_raw_descriptor(file.into_raw_descriptor())
        }))
    }
}
143
/// Keeps a mapping from the vmm's virtual addresses to guest addresses.
/// Used to translate messages from the vmm to guest offsets.
#[derive(Default)]
pub struct MappingInfo {
    /// Start of the mapping in the vmm's virtual address space.
    pub vmm_addr: u64,
    /// Guest-physical address corresponding to `vmm_addr`.
    pub guest_phys: u64,
    /// Length of the mapping in bytes.
    pub size: u64,
}
152
vmm_va_to_gpa(maps: &[MappingInfo], vmm_va: u64) -> VhostResult<GuestAddress>153 pub fn vmm_va_to_gpa(maps: &[MappingInfo], vmm_va: u64) -> VhostResult<GuestAddress> {
154 for map in maps {
155 if vmm_va >= map.vmm_addr && vmm_va < map.vmm_addr + map.size {
156 return Ok(GuestAddress(vmm_va - map.vmm_addr + map.guest_phys));
157 }
158 }
159 Err(VhostError::InvalidMessage)
160 }
161
/// Trait for vhost-user backend.
///
/// Implemented by each device; `DeviceRequestHandler` dispatches parsed
/// vhost-user requests to these methods.
pub trait VhostUserBackend {
    /// The maximum number of queues that this backend can manage.
    fn max_queue_num(&self) -> usize;

    /// The set of feature bits that this backend supports.
    fn features(&self) -> u64;

    /// Acknowledges that this set of features should be enabled.
    fn ack_features(&mut self, value: u64) -> anyhow::Result<()>;

    /// Returns the set of enabled features.
    fn acked_features(&self) -> u64;

    /// The set of protocol feature bits that this backend supports.
    fn protocol_features(&self) -> VhostUserProtocolFeatures;

    /// Acknowledges that this set of protocol features should be enabled.
    fn ack_protocol_features(&mut self, _value: u64) -> anyhow::Result<()>;

    /// Returns the set of enabled protocol features.
    fn acked_protocol_features(&self) -> u64;

    /// Reads this device configuration space at `offset`.
    fn read_config(&self, offset: u64, dst: &mut [u8]);

    /// Writes `data` to this device's configuration space at `offset`.
    ///
    /// The default implementation ignores writes.
    fn write_config(&self, _offset: u64, _data: &[u8]) {}

    /// Indicates that the backend should start processing requests for virtio queue number `idx`.
    /// This method must not block the current thread so device backends should either spawn an
    /// async task or another thread to handle messages from the Queue.
    fn start_queue(
        &mut self,
        idx: usize,
        queue: Queue,
        mem: GuestMemory,
        doorbell: Doorbell,
        kick_evt: Event,
    ) -> anyhow::Result<()>;

    /// Indicates that the backend should stop processing requests for virtio queue number `idx`.
    fn stop_queue(&mut self, idx: usize);

    /// Resets the vhost-user backend.
    fn reset(&mut self);

    /// Returns the device's shared memory region if present.
    ///
    /// The default implementation reports no shared memory region.
    fn get_shared_memory_region(&self) -> Option<SharedMemoryRegion> {
        None
    }

    /// Accepts `VhostBackendReqConnection` to conduct Vhost backend to frontend message
    /// handling.
    ///
    /// This method will be called when `VhostUserProtocolFeatures::SLAVE_REQ` is
    /// negotiated. The default implementation only logs an error.
    fn set_backend_req_connection(&mut self, _conn: VhostBackendReqConnection) {
        error!("set_backend_req_connection is not implemented");
    }
}
223
/// A virtio ring entry.
///
/// Tracks the partially configured state of one virtqueue while the frontend
/// sends the VHOST_USER_SET_VRING_* sequence of messages.
struct Vring {
    // The queue as configured so far by the frontend.
    queue: Queue,
    // Event used to signal the frontend that used buffers are available; set by
    // `set_vring_call`.
    doorbell: Option<Doorbell>,
    // Whether the ring may pass data; toggled by VHOST_USER_SET_VRING_ENABLE.
    enabled: bool,
}
230
231 impl Vring {
new(max_size: u16) -> Self232 fn new(max_size: u16) -> Self {
233 Self {
234 queue: Queue::new(max_size),
235 doorbell: None,
236 enabled: false,
237 }
238 }
239
reset(&mut self)240 fn reset(&mut self) {
241 self.queue.reset();
242 self.doorbell = None;
243 self.enabled = false;
244 }
245 }
246
/// Trait for defining vhost-user ops that are platform-dependent.
///
/// `DeviceRequestHandler` delegates memory-table and vring fd handling to an
/// implementation of this trait, so the same handler works for both the
/// regular socket-based protocol and other transports.
pub trait VhostUserPlatformOps {
    /// Returns the protocol implemented by these platform ops.
    fn protocol(&self) -> Protocol;
    /// Create the guest memory for the backend.
    ///
    /// `contexts` and `files` must be the same size, and provide a description of the memory
    /// regions to map as well as the file descriptors from which to obtain the memory backing these
    /// regions, respectively.
    ///
    /// The returned tuple contains the constructed `GuestMemory` from these memory contexts, as
    /// well as a vector describing all the mappings described by these contexts.

    fn set_mem_table(
        &mut self,
        contexts: &[VhostUserMemoryRegion],
        files: Vec<File>,
    ) -> VhostResult<(GuestMemory, Vec<MappingInfo>)>;

    /// Return an `Event` that will be signaled by the frontend whenever vring `index` should be
    /// processed.
    ///
    /// For protocols that support providing that event using a file descriptor (`Regular`), it is
    /// provided by `file`. For other protocols, `file` will be `None`.
    fn set_vring_kick(&mut self, index: u8, file: Option<File>) -> VhostResult<Event>;

    /// Return a `Doorbell` that the backend will signal whenever it puts used buffers for vring
    /// `index`.
    ///
    /// For protocols that support listening to a file descriptor (`Regular`), `file` provides a
    /// file descriptor from which the `Doorbell` should be built. For other protocols, it will be
    /// `None`.
    fn set_vring_call(&mut self, index: u8, file: Option<File>) -> VhostResult<Doorbell>;
}
281
/// Ops for running vhost-user over a stream (i.e. regular protocol).
///
/// Stateless: all configuration arrives through the trait method arguments.
pub(super) struct VhostUserRegularOps;
284
285 impl VhostUserPlatformOps for VhostUserRegularOps {
protocol(&self) -> Protocol286 fn protocol(&self) -> Protocol {
287 Protocol::Regular
288 }
289
set_mem_table( &mut self, contexts: &[VhostUserMemoryRegion], files: Vec<File>, ) -> VhostResult<(GuestMemory, Vec<MappingInfo>)>290 fn set_mem_table(
291 &mut self,
292 contexts: &[VhostUserMemoryRegion],
293 files: Vec<File>,
294 ) -> VhostResult<(GuestMemory, Vec<MappingInfo>)> {
295 if files.len() != contexts.len() {
296 return Err(VhostError::InvalidParam);
297 }
298
299 let mut regions = Vec::with_capacity(files.len());
300 for (region, file) in contexts.iter().zip(files.into_iter()) {
301 let region = MemoryRegion::new_from_shm(
302 region.memory_size,
303 GuestAddress(region.guest_phys_addr),
304 region.mmap_offset,
305 Arc::new(
306 SharedMemory::from_safe_descriptor(
307 SafeDescriptor::from(file),
308 Some(region.memory_size),
309 )
310 .unwrap(),
311 ),
312 )
313 .map_err(|e| {
314 error!("failed to create a memory region: {}", e);
315 VhostError::InvalidOperation
316 })?;
317 regions.push(region);
318 }
319 let guest_mem = GuestMemory::from_regions(regions).map_err(|e| {
320 error!("failed to create guest memory: {}", e);
321 VhostError::InvalidOperation
322 })?;
323
324 let vmm_maps = contexts
325 .iter()
326 .map(|region| MappingInfo {
327 vmm_addr: region.user_addr,
328 guest_phys: region.guest_phys_addr,
329 size: region.memory_size,
330 })
331 .collect();
332 Ok((guest_mem, vmm_maps))
333 }
334
set_vring_kick(&mut self, _index: u8, file: Option<File>) -> VhostResult<Event>335 fn set_vring_kick(&mut self, _index: u8, file: Option<File>) -> VhostResult<Event> {
336 let file = file.ok_or(VhostError::InvalidParam)?;
337 // Remove O_NONBLOCK from kick_fd. Otherwise, uring_executor will fails when we read
338 // values via `next_val()` later.
339 // This is only required (and can only be done) on Unix platforms.
340 #[cfg(unix)]
341 if let Err(e) = clear_fd_flags(file.as_raw_fd(), libc::O_NONBLOCK) {
342 error!("failed to remove O_NONBLOCK for kick fd: {}", e);
343 return Err(VhostError::InvalidParam);
344 }
345
346 // Safe because we own the file.
347 Ok(unsafe { Event::from_raw_descriptor(file.into_raw_descriptor()) })
348 }
349
set_vring_call(&mut self, _index: u8, file: Option<File>) -> VhostResult<Doorbell>350 fn set_vring_call(&mut self, _index: u8, file: Option<File>) -> VhostResult<Doorbell> {
351 let file = file.ok_or(VhostError::InvalidParam)?;
352 Ok(
353 // `Doorbell` is defined as `CallEvent` on Windows, prevent clippy from giving us a
354 // warning about the unneeded conversion.
355 #[allow(clippy::useless_conversion)]
356 Doorbell::from(CallEvent::try_from(file).map_err(|_| {
357 error!("failed to convert callfd to CallSignal");
358 VhostError::InvalidParam
359 })?),
360 )
361 }
362 }
363
/// A request handler for devices implementing `VhostUserBackend`.
pub struct DeviceRequestHandler {
    // One entry per queue, up to `backend.max_queue_num()`.
    vrings: Vec<Vring>,
    // Whether a frontend has claimed ownership via VHOST_USER_SET_OWNER.
    owned: bool,
    // vmm address mappings received via SET_MEM_TABLE; `None` until then.
    vmm_maps: Option<Vec<MappingInfo>>,
    // Guest memory built from SET_MEM_TABLE; `None` until then.
    mem: Option<GuestMemory>,
    // The device backend requests are dispatched to.
    backend: Box<dyn VhostUserBackend>,
    // Platform-dependent handling of memory tables and vring fds.
    ops: Box<dyn VhostUserPlatformOps>,
}
373
374 impl DeviceRequestHandler {
375 /// Creates a vhost-user handler instance for `backend` with a different set of platform ops
376 /// than the regular vhost-user ones.
new( backend: Box<dyn VhostUserBackend>, ops: Box<dyn VhostUserPlatformOps>, ) -> Self377 pub(crate) fn new(
378 backend: Box<dyn VhostUserBackend>,
379 ops: Box<dyn VhostUserPlatformOps>,
380 ) -> Self {
381 let mut vrings = Vec::with_capacity(backend.max_queue_num());
382 for _ in 0..backend.max_queue_num() {
383 vrings.push(Vring::new(MAX_VRING_LEN));
384 }
385
386 DeviceRequestHandler {
387 vrings,
388 owned: false,
389 vmm_maps: None,
390 mem: None,
391 backend,
392 ops,
393 }
394 }
395 }
396
impl VhostUserSlaveReqHandlerMut for DeviceRequestHandler {
    fn protocol(&self) -> Protocol {
        self.ops.protocol()
    }

    // VHOST_USER_SET_OWNER: marks this connection as owned by a frontend.
    // Fails if a frontend has already claimed ownership.
    fn set_owner(&mut self) -> VhostResult<()> {
        if self.owned {
            return Err(VhostError::InvalidOperation);
        }
        self.owned = true;
        Ok(())
    }

    // VHOST_USER_RESET_OWNER: releases ownership and resets the backend.
    fn reset_owner(&mut self) -> VhostResult<()> {
        self.owned = false;
        self.backend.reset();
        Ok(())
    }

    fn get_features(&mut self) -> VhostResult<u64> {
        let features = self.backend.features();
        Ok(features)
    }

    fn set_features(&mut self, features: u64) -> VhostResult<()> {
        // Features may only be set after the frontend claimed ownership.
        if !self.owned {
            return Err(VhostError::InvalidOperation);
        }

        // Reject any feature bit the backend did not advertise.
        if (features & !(self.backend.features())) != 0 {
            return Err(VhostError::InvalidParam);
        }

        if let Err(e) = self.backend.ack_features(features) {
            error!("failed to acknowledge features 0x{:x}: {}", features, e);
            return Err(VhostError::InvalidOperation);
        }

        // If VHOST_USER_F_PROTOCOL_FEATURES has not been negotiated, the ring is initialized in an
        // enabled state.
        // If VHOST_USER_F_PROTOCOL_FEATURES has been negotiated, the ring is initialized in a
        // disabled state.
        // Client must not pass data to/from the backend until ring is enabled by
        // VHOST_USER_SET_VRING_ENABLE with parameter 1, or after it has been disabled by
        // VHOST_USER_SET_VRING_ENABLE with parameter 0.
        // NOTE(review): the code below sets the rings *enabled* exactly when
        // PROTOCOL_FEATURES was acked, which appears inverted relative to the
        // spec quotation above — verify the intended polarity.
        let acked_features = self.backend.acked_features();
        let vring_enabled = VhostUserVirtioFeatures::PROTOCOL_FEATURES.bits() & acked_features != 0;
        for v in &mut self.vrings {
            v.enabled = vring_enabled;
        }

        Ok(())
    }

    fn get_protocol_features(&mut self) -> VhostResult<VhostUserProtocolFeatures> {
        Ok(self.backend.protocol_features())
    }

    fn set_protocol_features(&mut self, features: u64) -> VhostResult<()> {
        if let Err(e) = self.backend.ack_protocol_features(features) {
            error!("failed to set protocol features 0x{:x}: {}", features, e);
            return Err(VhostError::InvalidOperation);
        }
        Ok(())
    }

    // VHOST_USER_SET_MEM_TABLE: builds guest memory from the regions and
    // descriptors sent by the frontend, delegating to the platform ops.
    fn set_mem_table(
        &mut self,
        contexts: &[VhostUserMemoryRegion],
        files: Vec<File>,
    ) -> VhostResult<()> {
        let (guest_mem, vmm_maps) = self.ops.set_mem_table(contexts, files)?;
        self.mem = Some(guest_mem);
        self.vmm_maps = Some(vmm_maps);
        Ok(())
    }

    fn get_queue_num(&mut self) -> VhostResult<u64> {
        Ok(self.vrings.len() as u64)
    }

    fn set_vring_num(&mut self, index: u32, num: u32) -> VhostResult<()> {
        // The queue size must be non-zero and within the virtqueue limit.
        if index as usize >= self.vrings.len() || num == 0 || num > MAX_VRING_LEN.into() {
            return Err(VhostError::InvalidParam);
        }
        self.vrings[index as usize].queue.set_size(num as u16);

        Ok(())
    }

    // VHOST_USER_SET_VRING_ADDR: translates the vmm's virtual addresses of the
    // descriptor table and rings into guest-physical addresses.
    fn set_vring_addr(
        &mut self,
        index: u32,
        _flags: VhostUserVringAddrFlags,
        descriptor: u64,
        used: u64,
        available: u64,
        _log: u64,
    ) -> VhostResult<()> {
        if index as usize >= self.vrings.len() {
            return Err(VhostError::InvalidParam);
        }

        // Address translation requires a prior SET_MEM_TABLE.
        let vmm_maps = self.vmm_maps.as_ref().ok_or(VhostError::InvalidParam)?;
        let vring = &mut self.vrings[index as usize];
        vring
            .queue
            .set_desc_table(vmm_va_to_gpa(vmm_maps, descriptor)?);
        vring
            .queue
            .set_avail_ring(vmm_va_to_gpa(vmm_maps, available)?);
        vring.queue.set_used_ring(vmm_va_to_gpa(vmm_maps, used)?);

        Ok(())
    }

    fn set_vring_base(&mut self, index: u32, base: u32) -> VhostResult<()> {
        if index as usize >= self.vrings.len() || base >= MAX_VRING_LEN.into() {
            return Err(VhostError::InvalidParam);
        }

        let vring = &mut self.vrings[index as usize];
        vring.queue.next_avail = Wrapping(base as u16);
        vring.queue.next_used = Wrapping(base as u16);

        Ok(())
    }

    fn get_vring_base(&mut self, index: u32) -> VhostResult<VhostUserVringState> {
        if index as usize >= self.vrings.len() {
            return Err(VhostError::InvalidParam);
        }

        // Quotation from vhost-user spec:
        // Client must start ring upon receiving a kick (that is, detecting
        // that file descriptor is readable) on the descriptor specified by
        // VHOST_USER_SET_VRING_KICK, and stop ring upon receiving
        // VHOST_USER_GET_VRING_BASE.
        self.backend.stop_queue(index as usize);

        let vring = &mut self.vrings[index as usize];
        vring.reset();

        // NOTE(review): `vring.reset()` above resets the queue, so the
        // `next_avail` returned here is the post-reset value — verify this is
        // the base the frontend expects rather than the pre-reset index.
        Ok(VhostUserVringState::new(
            index,
            vring.queue.next_avail.0 as u32,
        ))
    }

    // VHOST_USER_SET_VRING_KICK: receives the kick event, activates the queue,
    // and hands it off to the backend for processing.
    fn set_vring_kick(&mut self, index: u8, file: Option<File>) -> VhostResult<()> {
        if index as usize >= self.vrings.len() {
            return Err(VhostError::InvalidParam);
        }

        let vring = &mut self.vrings[index as usize];
        if vring.queue.ready() {
            error!("kick fd cannot replaced after queue is started");
            return Err(VhostError::InvalidOperation);
        }

        let kick_evt = self.ops.set_vring_kick(index, file)?;

        // Enable any virtqueue features that were negotiated (like VIRTIO_RING_F_EVENT_IDX).
        vring.queue.ack_features(self.backend.acked_features());
        vring.queue.set_ready(true);

        let queue = match vring.queue.activate() {
            Ok(queue) => queue,
            Err(e) => {
                error!("failed to activate vring: {:#}", e);
                return Err(VhostError::SlaveInternalError);
            }
        };

        // Starting the queue requires the doorbell (SET_VRING_CALL) and guest
        // memory (SET_MEM_TABLE) to have been configured already.
        let doorbell = vring.doorbell.clone().ok_or(VhostError::InvalidOperation)?;
        let mem = self
            .mem
            .as_ref()
            .cloned()
            .ok_or(VhostError::InvalidOperation)?;

        if let Err(e) = self
            .backend
            .start_queue(index as usize, queue, mem, doorbell, kick_evt)
        {
            error!("Failed to start queue {}: {}", index, e);
            return Err(VhostError::SlaveInternalError);
        }

        Ok(())
    }

    // VHOST_USER_SET_VRING_CALL: stores the doorbell the backend signals when
    // it puts used buffers.
    fn set_vring_call(&mut self, index: u8, file: Option<File>) -> VhostResult<()> {
        if index as usize >= self.vrings.len() {
            return Err(VhostError::InvalidParam);
        }

        let doorbell = self.ops.set_vring_call(index, file)?;
        self.vrings[index as usize].doorbell = Some(doorbell);
        Ok(())
    }

    fn set_vring_err(&mut self, _index: u8, _fd: Option<File>) -> VhostResult<()> {
        // TODO
        Ok(())
    }

    fn set_vring_enable(&mut self, index: u32, enable: bool) -> VhostResult<()> {
        if index as usize >= self.vrings.len() {
            return Err(VhostError::InvalidParam);
        }

        // This request should be handled only when VHOST_USER_F_PROTOCOL_FEATURES
        // has been negotiated.
        if self.backend.acked_features() & VhostUserVirtioFeatures::PROTOCOL_FEATURES.bits() == 0 {
            return Err(VhostError::InvalidOperation);
        }

        // Slave must not pass data to/from the backend until ring is
        // enabled by VHOST_USER_SET_VRING_ENABLE with parameter 1,
        // or after it has been disabled by VHOST_USER_SET_VRING_ENABLE
        // with parameter 0.
        self.vrings[index as usize].enabled = enable;

        Ok(())
    }

    // VHOST_USER_GET_CONFIG: reads `size` bytes of the device's config space.
    fn get_config(
        &mut self,
        offset: u32,
        size: u32,
        _flags: VhostUserConfigFlags,
    ) -> VhostResult<Vec<u8>> {
        let mut data = vec![0; size as usize];
        self.backend.read_config(u64::from(offset), &mut data);
        Ok(data)
    }

    // VHOST_USER_SET_CONFIG: writes `buf` into the device's config space.
    fn set_config(
        &mut self,
        offset: u32,
        buf: &[u8],
        _flags: VhostUserConfigFlags,
    ) -> VhostResult<()> {
        self.backend.write_config(u64::from(offset), buf);
        Ok(())
    }

    // VHOST_USER_SET_SLAVE_REQ_FD: wraps the endpoint in a
    // `VhostBackendReqConnection` and hands it to the backend so it can send
    // backend-to-frontend requests.
    fn set_slave_req_fd(&mut self, ep: Box<dyn Endpoint<SlaveReq>>) {
        let conn = VhostBackendReqConnection::new(
            Slave::new(ep),
            self.backend.get_shared_memory_region().map(|r| r.id),
        );
        self.backend.set_backend_req_connection(conn);
    }

    fn get_inflight_fd(
        &mut self,
        _inflight: &VhostUserInflight,
    ) -> VhostResult<(VhostUserInflight, File)> {
        unimplemented!("get_inflight_fd");
    }

    fn set_inflight_fd(&mut self, _inflight: &VhostUserInflight, _file: File) -> VhostResult<()> {
        unimplemented!("set_inflight_fd");
    }

    fn get_max_mem_slots(&mut self) -> VhostResult<u64> {
        //TODO
        Ok(0)
    }

    fn add_mem_region(
        &mut self,
        _region: &VhostUserSingleMemoryRegion,
        _fd: File,
    ) -> VhostResult<()> {
        //TODO
        Ok(())
    }

    fn remove_mem_region(&mut self, _region: &VhostUserSingleMemoryRegion) -> VhostResult<()> {
        //TODO
        Ok(())
    }

    // Reports the backend's shared memory region (at most one), if any.
    fn get_shared_memory_regions(&mut self) -> VhostResult<Vec<VhostSharedMemoryRegion>> {
        Ok(if let Some(r) = self.backend.get_shared_memory_region() {
            vec![VhostSharedMemoryRegion::new(r.id, r.length)]
        } else {
            Vec::new()
        })
    }
}
691
/// Indicates the state of the backend request connection.
pub enum VhostBackendReqConnectionState {
    /// A backend request connection (`VhostBackendReqConnection`) is established.
    Connected(VhostBackendReqConnection),
    /// No backend request connection has been established yet.
    NoConnection,
}
699
/// Keeps track of Vhost user backend request connection.
pub struct VhostBackendReqConnection {
    // Channel used to send backend-to-frontend requests.
    conn: Slave,
    // Shared memory state; `None` if the device has no shared memory region or
    // the mapper has already been taken via `take_shmem_mapper`.
    shmem_info: Option<ShmemInfo>,
}
705
// State of the device's shared memory region, used to build map/unmap
// messages sent to the frontend.
#[derive(Clone)]
struct ShmemInfo {
    // Id of the shared memory region the mappings belong to.
    shmid: u8,
    mapped_regions: BTreeMap<u64 /* offset */, u64 /* size */>,
}
711
712 impl VhostBackendReqConnection {
new(conn: Slave, shmid: Option<u8>) -> Self713 pub fn new(conn: Slave, shmid: Option<u8>) -> Self {
714 let shmem_info = shmid.map(|shmid| ShmemInfo {
715 shmid,
716 mapped_regions: BTreeMap::new(),
717 });
718 Self { conn, shmem_info }
719 }
720
721 /// Send `VHOST_USER_CONFIG_CHANGE_MSG` to the frontend
send_config_changed(&self) -> anyhow::Result<()>722 pub fn send_config_changed(&self) -> anyhow::Result<()> {
723 self.conn
724 .handle_config_change()
725 .context("Could not send config change message")?;
726 Ok(())
727 }
728
729 /// Create a SharedMemoryMapper trait object from the ShmemInfo.
take_shmem_mapper(&mut self) -> anyhow::Result<Box<dyn SharedMemoryMapper>>730 pub fn take_shmem_mapper(&mut self) -> anyhow::Result<Box<dyn SharedMemoryMapper>> {
731 let shmem_info = self
732 .shmem_info
733 .take()
734 .context("could not take shared memory mapper information")?;
735
736 Ok(Box::new(VhostShmemMapper {
737 conn: self.conn.clone(),
738 shmem_info,
739 }))
740 }
741 }
742
// `SharedMemoryMapper` implementation that forwards map/unmap requests to the
// frontend over the backend request connection.
struct VhostShmemMapper {
    conn: Slave,
    shmem_info: ShmemInfo,
}
747
748 impl SharedMemoryMapper for VhostShmemMapper {
add_mapping( &mut self, source: VmMemorySource, offset: u64, prot: Protection, ) -> anyhow::Result<()>749 fn add_mapping(
750 &mut self,
751 source: VmMemorySource,
752 offset: u64,
753 prot: Protection,
754 ) -> anyhow::Result<()> {
755 // True if we should send gpu_map instead of shmem_map.
756 let is_gpu = matches!(&source, &VmMemorySource::Vulkan { .. });
757
758 let size = if is_gpu {
759 match source {
760 VmMemorySource::Vulkan {
761 descriptor,
762 handle_type,
763 memory_idx,
764 device_id,
765 size,
766 } => {
767 let msg = VhostUserGpuMapMsg::new(
768 self.shmem_info.shmid,
769 offset,
770 size,
771 memory_idx,
772 handle_type,
773 device_id.device_uuid,
774 device_id.driver_uuid,
775 );
776 self.conn
777 .gpu_map(&msg, &descriptor)
778 .context("failed to map memory")?;
779 size
780 }
781 _ => unreachable!("inconsistent pattern match"),
782 }
783 } else {
784 let (descriptor, fd_offset, size) = match source {
785 VmMemorySource::Descriptor {
786 descriptor,
787 offset,
788 size,
789 } => (descriptor, offset, size),
790 VmMemorySource::SharedMemory(shmem) => {
791 let size = shmem.size();
792 // Safe because we own shmem.
793 let descriptor =
794 unsafe { SafeDescriptor::from_raw_descriptor(shmem.into_raw_descriptor()) };
795 (descriptor, 0, size)
796 }
797 _ => bail!("unsupported source"),
798 };
799 let flags = VhostUserShmemMapMsgFlags::from(prot);
800 let msg =
801 VhostUserShmemMapMsg::new(self.shmem_info.shmid, offset, fd_offset, size, flags);
802 self.conn
803 .shmem_map(&msg, &descriptor)
804 .context("failed to map memory")?;
805 size
806 };
807
808 self.shmem_info.mapped_regions.insert(offset, size);
809 Ok(())
810 }
811
remove_mapping(&mut self, offset: u64) -> anyhow::Result<()>812 fn remove_mapping(&mut self, offset: u64) -> anyhow::Result<()> {
813 let size = self
814 .shmem_info
815 .mapped_regions
816 .remove(&offset)
817 .context("unknown offset")?;
818 let msg = VhostUserShmemUnmapMsg::new(self.shmem_info.shmid, offset, size);
819 self.conn
820 .shmem_unmap(&msg)
821 .context("failed to map memory")
822 .map(|_| ())
823 }
824 }
825
826 #[cfg(test)]
827 mod tests {
828 #[cfg(unix)]
829 use std::sync::mpsc::channel;
830 #[cfg(unix)]
831 use std::sync::Barrier;
832
833 use anyhow::anyhow;
834 use anyhow::bail;
835 #[cfg(unix)]
836 use tempfile::Builder;
837 #[cfg(unix)]
838 use tempfile::TempDir;
839 use vmm_vhost::message::MasterReq;
840 use vmm_vhost::SlaveReqHandler;
841 use vmm_vhost::VhostUserSlaveReqHandler;
842 use zerocopy::AsBytes;
843 use zerocopy::FromBytes;
844
845 use super::*;
846 use crate::virtio::vhost::user::vmm::VhostUserHandler;
847
    // Fake device config space contents used to exercise get_config/read_config.
    #[derive(Clone, Copy, Debug, PartialEq, Eq, AsBytes, FromBytes)]
    #[repr(C, packed(4))]
    struct FakeConfig {
        x: u32,
        y: u64,
    }

    const FAKE_CONFIG_DATA: FakeConfig = FakeConfig { x: 1, y: 2 };
856
    // Minimal `VhostUserBackend` implementation used to exercise the request
    // handler in tests.
    pub(super) struct FakeBackend {
        avail_features: u64,
        acked_features: u64,
        acked_protocol_features: VhostUserProtocolFeatures,
    }
862
863 impl FakeBackend {
864 const MAX_QUEUE_NUM: usize = 16;
865
new() -> Self866 pub(super) fn new() -> Self {
867 Self {
868 avail_features: VhostUserVirtioFeatures::PROTOCOL_FEATURES.bits(),
869 acked_features: 0,
870 acked_protocol_features: VhostUserProtocolFeatures::empty(),
871 }
872 }
873 }
874
    impl VhostUserBackend for FakeBackend {
        fn max_queue_num(&self) -> usize {
            Self::MAX_QUEUE_NUM
        }

        fn features(&self) -> u64 {
            self.avail_features
        }

        fn ack_features(&mut self, value: u64) -> anyhow::Result<()> {
            // Reject any bit that was not advertised in avail_features.
            let unrequested_features = value & !self.avail_features;
            if unrequested_features != 0 {
                bail!(
                    "invalid protocol features are given: 0x{:x}",
                    unrequested_features
                );
            }
            self.acked_features |= value;
            Ok(())
        }

        fn acked_features(&self) -> u64 {
            self.acked_features
        }

        fn protocol_features(&self) -> VhostUserProtocolFeatures {
            VhostUserProtocolFeatures::CONFIG
        }

        fn ack_protocol_features(&mut self, features: u64) -> anyhow::Result<()> {
            let features = VhostUserProtocolFeatures::from_bits(features).ok_or(anyhow!(
                "invalid protocol features are given: 0x{:x}",
                features
            ))?;
            // Silently drop unsupported bits; only the intersection is acked.
            let supported = self.protocol_features();
            self.acked_protocol_features = features & supported;
            Ok(())
        }

        fn acked_protocol_features(&self) -> u64 {
            self.acked_protocol_features.bits()
        }

        fn read_config(&self, offset: u64, dst: &mut [u8]) {
            // NOTE(review): copy_from_slice panics unless `dst` is exactly the
            // length of the config data past `offset`; acceptable for tests
            // that always request matching sizes.
            dst.copy_from_slice(&FAKE_CONFIG_DATA.as_bytes()[offset as usize..]);
        }

        fn reset(&mut self) {}

        // Queue start/stop are no-ops for the fake backend.
        fn start_queue(
            &mut self,
            _idx: usize,
            _queue: Queue,
            _mem: GuestMemory,
            _doorbell: Doorbell,
            _kick_evt: Event,
        ) -> anyhow::Result<()> {
            Ok(())
        }

        fn stop_queue(&mut self, _idx: usize) {}
    }
937
938 #[cfg(unix)]
temp_dir() -> TempDir939 fn temp_dir() -> TempDir {
940 Builder::new().prefix("/tmp/vhost_test").tempdir().unwrap()
941 }
942
#[cfg(unix)]
#[test]
fn test_vhost_user_activate() {
    use std::os::unix::net::UnixStream;

    use vmm_vhost::connection::socket::Listener as SocketListener;
    use vmm_vhost::SlaveListener;

    const QUEUES_NUM: usize = 2;

    let dir = temp_dir();
    let mut path = dir.path().to_owned();
    path.push("sock");
    let listener = SocketListener::new(&path, true).unwrap();

    // Both sides rendezvous here so the device side outlives the VMM side.
    let vmm_bar = Arc::new(Barrier::new(2));
    let dev_bar = vmm_bar.clone();

    let (tx, rx) = channel();

    std::thread::spawn(move || {
        // VMM side
        rx.recv().unwrap(); // Ensure the device is ready.

        let allow_features = VhostUserVirtioFeatures::PROTOCOL_FEATURES.bits();
        let init_features = VhostUserVirtioFeatures::PROTOCOL_FEATURES.bits();
        let allow_protocol_features = VhostUserProtocolFeatures::CONFIG;
        let connection = UnixStream::connect(&path).unwrap();
        let mut vmm_handler = VhostUserHandler::new_from_connection(
            connection,
            QUEUES_NUM as u64,
            allow_features,
            init_features,
            allow_protocol_features,
        )
        .unwrap();

        // Issue the standard request sequence (read_config, set_mem_table,
        // per-queue vring activation) via the shared helper instead of
        // duplicating it inline.
        vmm_handler_send_requests(&mut vmm_handler, QUEUES_NUM);

        // The VMM side is supposed to stop before the device side.
        drop(vmm_handler);

        vmm_bar.wait();
    });

    // Device side
    let handler = std::sync::Mutex::new(DeviceRequestHandler::new(
        Box::new(FakeBackend::new()),
        Box::new(VhostUserRegularOps),
    ));
    let mut listener = SlaveListener::<SocketListener, _>::new(listener, handler).unwrap();

    // Notify listener is ready.
    tx.send(()).unwrap();

    let mut listener = listener.accept().unwrap().unwrap();

    // Service exactly the messages produced by `vmm_handler_send_requests`
    // above, via the shared helper instead of duplicating the sequence.
    test_handle_requests(&mut listener, QUEUES_NUM);

    dev_bar.wait();

    // After the VMM handler is dropped, the next read must observe the
    // client hang-up.
    match handle_request(&mut listener) {
        Err(VhostError::ClientExit) => (),
        r => panic!("Err(ClientExit) was expected but {:?}", r),
    }
}
1050
vmm_handler_send_requests(vmm_handler: &mut VhostUserHandler, queues_num: usize)1051 pub(super) fn vmm_handler_send_requests(vmm_handler: &mut VhostUserHandler, queues_num: usize) {
1052 println!("read_config");
1053 let mut buf = vec![0; std::mem::size_of::<FakeConfig>()];
1054 vmm_handler.read_config(0, &mut buf).unwrap();
1055 // Check if the obtained config data is correct.
1056 let config = FakeConfig::read_from(buf.as_bytes()).unwrap();
1057 assert_eq!(config, FAKE_CONFIG_DATA);
1058
1059 println!("set_mem_table");
1060 let mem = GuestMemory::new(&[(GuestAddress(0x0), 0x10000)]).unwrap();
1061 vmm_handler.set_mem_table(&mem).unwrap();
1062
1063 for idx in 0..queues_num {
1064 println!("activate_mem_table: queue_index={}", idx);
1065 let queue = Queue::new(0x10);
1066 let queue_evt = Event::new().unwrap();
1067 let irqfd = Event::new().unwrap();
1068
1069 vmm_handler
1070 .activate_vring(&mem, idx, &queue, &queue_evt, &irqfd)
1071 .unwrap();
1072 }
1073 }
1074
handle_request<S: VhostUserSlaveReqHandler, E: Endpoint<MasterReq>>( handler: &mut SlaveReqHandler<S, E>, ) -> Result<(), VhostError>1075 fn handle_request<S: VhostUserSlaveReqHandler, E: Endpoint<MasterReq>>(
1076 handler: &mut SlaveReqHandler<S, E>,
1077 ) -> Result<(), VhostError> {
1078 let (hdr, files) = handler.recv_header()?;
1079 handler.process_message(hdr, files)
1080 }
1081
test_handle_requests<S: VhostUserSlaveReqHandler, E: Endpoint<MasterReq>>( req_handler: &mut SlaveReqHandler<S, E>, queues_num: usize, )1082 pub(super) fn test_handle_requests<S: VhostUserSlaveReqHandler, E: Endpoint<MasterReq>>(
1083 req_handler: &mut SlaveReqHandler<S, E>,
1084 queues_num: usize,
1085 ) {
1086 // VhostUserHandler::new()
1087 handle_request(req_handler).expect("set_owner");
1088 handle_request(req_handler).expect("get_features");
1089 handle_request(req_handler).expect("set_features");
1090 handle_request(req_handler).expect("get_protocol_features");
1091 handle_request(req_handler).expect("set_protocol_features");
1092
1093 // VhostUserHandler::read_config()
1094 handle_request(req_handler).expect("get_config");
1095
1096 // VhostUserHandler::set_mem_table()
1097 handle_request(req_handler).expect("set_mem_table");
1098
1099 for _ in 0..queues_num {
1100 // VhostUserHandler::activate_vring()
1101 handle_request(req_handler).expect("set_vring_num");
1102 handle_request(req_handler).expect("set_vring_addr");
1103 handle_request(req_handler).expect("set_vring_base");
1104 handle_request(req_handler).expect("set_vring_call");
1105 handle_request(req_handler).expect("set_vring_kick");
1106 handle_request(req_handler).expect("set_vring_enable");
1107 }
1108 }
1109 }
1110