1 // Copyright 2021 The ChromiumOS Authors
2 // Use of this source code is governed by a BSD-style license that can be
3 // found in the LICENSE file.
4
5 mod sys;
6 mod worker;
7
8 use std::sync::Mutex;
9
10 use base::error;
11 use base::info;
12 use base::AsRawDescriptor;
13 use base::Event;
14 use base::Protection;
15 use base::SafeDescriptor;
16 use base::WorkerThread;
17 use rutabaga_gfx::DeviceId;
18 use vm_control::VmMemorySource;
19 use vm_memory::GuestMemory;
20 use vm_memory::MemoryRegionInformation;
21 use vmm_vhost::message::VhostUserConfigFlags;
22 use vmm_vhost::message::VhostUserGpuMapMsg;
23 use vmm_vhost::message::VhostUserProtocolFeatures;
24 use vmm_vhost::message::VhostUserShmemMapMsg;
25 use vmm_vhost::message::VhostUserShmemUnmapMsg;
26 use vmm_vhost::message::VhostUserVirtioFeatures;
27 use vmm_vhost::HandlerResult;
28 use vmm_vhost::MasterReqHandler;
29 use vmm_vhost::VhostBackend;
30 use vmm_vhost::VhostUserMaster;
31 use vmm_vhost::VhostUserMasterReqHandlerMut;
32 use vmm_vhost::VhostUserMemoryRegionInfo;
33 use vmm_vhost::VringConfigData;
34
35 use crate::virtio::vhost::user::vmm::handler::sys::create_backend_req_handler;
36 use crate::virtio::vhost::user::vmm::handler::sys::SocketMaster;
37 use crate::virtio::vhost::user::vmm::Error;
38 use crate::virtio::vhost::user::vmm::Result;
39 use crate::virtio::Interrupt;
40 use crate::virtio::Queue;
41 use crate::virtio::SharedMemoryMapper;
42 use crate::virtio::SharedMemoryRegion;
43 use crate::virtio::SignalableInterrupt;
44
/// Request handler type used for backend-initiated requests: a `MasterReqHandler` wrapping a
/// mutex-protected `BackendReqHandlerImpl`.
type BackendReqHandler = MasterReqHandler<Mutex<BackendReqHandlerImpl>>;
46
set_features(vu: &mut SocketMaster, avail_features: u64, ack_features: u64) -> Result<u64>47 fn set_features(vu: &mut SocketMaster, avail_features: u64, ack_features: u64) -> Result<u64> {
48 let features = avail_features & ack_features;
49 vu.set_features(features).map_err(Error::SetFeatures)?;
50 Ok(features)
51 }
52
/// VMM-side state for a single vhost-user device connection.
///
/// Holds the `SocketMaster` used to send requests to the backend together with the virtio and
/// protocol feature sets negotiated so far.
pub struct VhostUserHandler {
    // Connection over which vhost-user requests are sent to the backend.
    vu: SocketMaster,
    /// Features reported by the backend, masked by the `allow_features` passed to `new`.
    pub avail_features: u64,
    // Feature mask most recently sent to the backend through `set_features`.
    acked_features: u64,
    // Protocol features negotiated in `new`; empty when `PROTOCOL_FEATURES` was not acked.
    protocol_features: VhostUserProtocolFeatures,
    // Handler for backend-initiated requests. `Some` only when `SLAVE_REQ` was negotiated;
    // `activate` takes it out and hands it to the worker.
    backend_req_handler: Option<BackendReqHandler>,
    // Shared memory region info. IPC result from backend is saved with outer Option.
    shmem_region: Option<Option<SharedMemoryRegion>>,
    // On Windows, we need a backend pid to support backend requests.
    #[cfg(windows)]
    backend_pid: Option<u32>,
}
65
66 impl VhostUserHandler {
67 /// Creates a `VhostUserHandler` instance with features and protocol features initialized.
new( mut vu: SocketMaster, allow_features: u64, init_features: u64, allow_protocol_features: VhostUserProtocolFeatures, #[cfg(windows)] backend_pid: Option<u32>, ) -> Result<Self>68 fn new(
69 mut vu: SocketMaster,
70 allow_features: u64,
71 init_features: u64,
72 allow_protocol_features: VhostUserProtocolFeatures,
73 #[cfg(windows)] backend_pid: Option<u32>,
74 ) -> Result<Self> {
75 vu.set_owner().map_err(Error::SetOwner)?;
76
77 let avail_features = allow_features & vu.get_features().map_err(Error::GetFeatures)?;
78 let acked_features = set_features(&mut vu, avail_features, init_features)?;
79
80 let mut protocol_features = VhostUserProtocolFeatures::empty();
81 if acked_features & VhostUserVirtioFeatures::PROTOCOL_FEATURES.bits() != 0 {
82 let avail_protocol_features = vu
83 .get_protocol_features()
84 .map_err(Error::GetProtocolFeatures)?;
85 protocol_features = allow_protocol_features & avail_protocol_features;
86 vu.set_protocol_features(protocol_features)
87 .map_err(Error::SetProtocolFeatures)?;
88 }
89
90 // if protocol feature `VhostUserProtocolFeatures::SLAVE_REQ` is negotiated.
91 let backend_req_handler =
92 if protocol_features.contains(VhostUserProtocolFeatures::SLAVE_REQ) {
93 let mut handler = create_backend_req_handler(
94 BackendReqHandlerImpl {
95 interrupt: None,
96 shared_mapper_state: None,
97 },
98 #[cfg(windows)]
99 backend_pid,
100 )?;
101 vu.set_slave_request_fd(&handler.take_tx_descriptor())
102 .map_err(Error::SetDeviceRequestChannel)?;
103 Some(handler)
104 } else {
105 None
106 };
107
108 Ok(VhostUserHandler {
109 vu,
110 avail_features,
111 acked_features,
112 protocol_features,
113 backend_req_handler,
114 shmem_region: None,
115 #[cfg(windows)]
116 backend_pid,
117 })
118 }
119
120 /// Returns a vector of sizes of each queue.
queue_sizes(&mut self, queue_size: u16, default_queues_num: usize) -> Result<Vec<u16>>121 pub fn queue_sizes(&mut self, queue_size: u16, default_queues_num: usize) -> Result<Vec<u16>> {
122 let queues_num = if self
123 .protocol_features
124 .contains(VhostUserProtocolFeatures::MQ)
125 {
126 self.vu.get_queue_num().map_err(Error::GetQueueNum)? as usize
127 } else {
128 default_queues_num
129 };
130 Ok(vec![queue_size; queues_num])
131 }
132
133 /// Enables a set of features.
ack_features(&mut self, ack_features: u64) -> Result<()>134 pub fn ack_features(&mut self, ack_features: u64) -> Result<()> {
135 let features = set_features(
136 &mut self.vu,
137 self.avail_features,
138 self.acked_features | ack_features,
139 )?;
140 self.acked_features = features;
141 Ok(())
142 }
143
144 /// Gets the device configuration space at `offset` and writes it into `data`.
read_config(&mut self, offset: u64, data: &mut [u8]) -> Result<()>145 pub fn read_config(&mut self, offset: u64, data: &mut [u8]) -> Result<()> {
146 let (_, config) = self
147 .vu
148 .get_config(
149 offset
150 .try_into()
151 .map_err(|_| Error::InvalidConfigOffset(offset))?,
152 data.len()
153 .try_into()
154 .map_err(|_| Error::InvalidConfigLen(data.len()))?,
155 VhostUserConfigFlags::WRITABLE,
156 data,
157 )
158 .map_err(Error::GetConfig)?;
159 data.copy_from_slice(&config);
160 Ok(())
161 }
162
163 /// Writes `data` into the device configuration space at `offset`.
write_config(&mut self, offset: u64, data: &[u8]) -> Result<()>164 pub fn write_config(&mut self, offset: u64, data: &[u8]) -> Result<()> {
165 self.vu
166 .set_config(
167 offset
168 .try_into()
169 .map_err(|_| Error::InvalidConfigOffset(offset))?,
170 VhostUserConfigFlags::empty(),
171 data,
172 )
173 .map_err(Error::SetConfig)
174 }
175
176 /// Sets the memory map regions so it can translate the vring addresses.
set_mem_table(&mut self, mem: &GuestMemory) -> Result<()>177 pub fn set_mem_table(&mut self, mem: &GuestMemory) -> Result<()> {
178 let mut regions: Vec<VhostUserMemoryRegionInfo> = Vec::new();
179 mem.with_regions::<_, ()>(
180 |MemoryRegionInformation {
181 guest_addr,
182 size,
183 host_addr,
184 shm,
185 shm_offset,
186 ..
187 }| {
188 let region = VhostUserMemoryRegionInfo {
189 guest_phys_addr: guest_addr.0,
190 memory_size: size as u64,
191 userspace_addr: host_addr as u64,
192 mmap_offset: shm_offset as u64,
193 mmap_handle: shm.as_raw_descriptor(),
194 };
195 regions.push(region);
196 Ok(())
197 },
198 )
199 .unwrap(); // never fail
200
201 self.vu
202 .set_mem_table(regions.as_slice())
203 .map_err(Error::SetMemTable)?;
204
205 Ok(())
206 }
207
208 /// Activates a vring for the given `queue`.
activate_vring( &mut self, mem: &GuestMemory, queue_index: usize, queue: &Queue, queue_evt: &Event, irqfd: &Event, ) -> Result<()>209 pub fn activate_vring(
210 &mut self,
211 mem: &GuestMemory,
212 queue_index: usize,
213 queue: &Queue,
214 queue_evt: &Event,
215 irqfd: &Event,
216 ) -> Result<()> {
217 self.vu
218 .set_vring_num(queue_index, queue.size())
219 .map_err(Error::SetVringNum)?;
220
221 let config_data = VringConfigData {
222 queue_max_size: queue.max_size(),
223 queue_size: queue.size(),
224 flags: 0u32,
225 desc_table_addr: mem
226 .get_host_address(queue.desc_table())
227 .map_err(Error::GetHostAddress)? as u64,
228 used_ring_addr: mem
229 .get_host_address(queue.used_ring())
230 .map_err(Error::GetHostAddress)? as u64,
231 avail_ring_addr: mem
232 .get_host_address(queue.avail_ring())
233 .map_err(Error::GetHostAddress)? as u64,
234 log_addr: None,
235 };
236 self.vu
237 .set_vring_addr(queue_index, &config_data)
238 .map_err(Error::SetVringAddr)?;
239
240 self.vu
241 .set_vring_base(queue_index, 0)
242 .map_err(Error::SetVringBase)?;
243
244 self.vu
245 .set_vring_call(queue_index, irqfd)
246 .map_err(Error::SetVringCall)?;
247 self.vu
248 .set_vring_kick(queue_index, queue_evt)
249 .map_err(Error::SetVringKick)?;
250 self.vu
251 .set_vring_enable(queue_index, true)
252 .map_err(Error::SetVringEnable)?;
253
254 Ok(())
255 }
256
257 /// Activates vrings.
activate( &mut self, mem: GuestMemory, interrupt: Interrupt, queues: Vec<(Queue, Event)>, label: &str, ) -> Result<WorkerThread<()>>258 pub fn activate(
259 &mut self,
260 mem: GuestMemory,
261 interrupt: Interrupt,
262 queues: Vec<(Queue, Event)>,
263 label: &str,
264 ) -> Result<WorkerThread<()>> {
265 self.set_mem_table(&mem)?;
266
267 let msix_config_opt = interrupt
268 .get_msix_config()
269 .as_ref()
270 .ok_or(Error::MsixConfigUnavailable)?;
271 let msix_config = msix_config_opt.lock();
272
273 let non_msix_evt = Event::new().map_err(Error::CreateEvent)?;
274 for (queue_index, (queue, queue_evt)) in queues.iter().enumerate() {
275 let irqfd = msix_config
276 .get_irqfd(queue.vector() as usize)
277 .unwrap_or(&non_msix_evt);
278 self.activate_vring(&mem, queue_index, queue, queue_evt, irqfd)?;
279 }
280
281 drop(msix_config);
282
283 let label = format!("vhost_user_virtio_{}", label);
284
285 let backend_req_handler = self.backend_req_handler.take();
286 if let Some(handler) = &backend_req_handler {
287 // Using unwrap here to get the mutex protected value
288 handler
289 .backend()
290 .lock()
291 .unwrap()
292 .set_interrupt(interrupt.clone());
293 }
294
295 Ok(WorkerThread::start(label.clone(), move |kill_evt| {
296 let mut worker = worker::Worker {
297 queues,
298 mem,
299 kill_evt,
300 non_msix_evt,
301 backend_req_handler,
302 };
303
304 if let Err(e) = worker.run(interrupt) {
305 error!("failed to start {} worker: {}", label, e);
306 }
307 }))
308 }
309
310 /// Deactivates all vrings.
reset(&mut self, queues_num: usize) -> Result<()>311 pub fn reset(&mut self, queues_num: usize) -> Result<()> {
312 for queue_index in 0..queues_num {
313 self.vu
314 .set_vring_enable(queue_index, false)
315 .map_err(Error::SetVringEnable)?;
316 self.vu
317 .get_vring_base(queue_index)
318 .map_err(Error::GetVringBase)?;
319 }
320 Ok(())
321 }
322
get_shared_memory_region(&mut self) -> Result<Option<SharedMemoryRegion>>323 pub fn get_shared_memory_region(&mut self) -> Result<Option<SharedMemoryRegion>> {
324 if !self
325 .protocol_features
326 .contains(VhostUserProtocolFeatures::SHARED_MEMORY_REGIONS)
327 {
328 return Ok(None);
329 }
330 if let Some(r) = self.shmem_region.as_ref() {
331 return Ok(r.clone());
332 }
333 let regions = self
334 .vu
335 .get_shared_memory_regions()
336 .map_err(Error::ShmemRegions)?;
337 let region = match regions.len() {
338 0 => None,
339 1 => Some(SharedMemoryRegion {
340 id: regions[0].id,
341 length: regions[0].length,
342 }),
343 n => return Err(Error::TooManyShmemRegions(n)),
344 };
345
346 self.shmem_region = Some(region.clone());
347 Ok(region)
348 }
349
set_shared_memory_mapper(&mut self, mapper: Box<dyn SharedMemoryMapper>) -> Result<()>350 pub fn set_shared_memory_mapper(&mut self, mapper: Box<dyn SharedMemoryMapper>) -> Result<()> {
351 // Return error if backend request handler is not available. This indicates
352 // that `VhostUserProtocolFeatures::SLAVE_REQ` is not negotiated.
353 let backend_req_handler =
354 self.backend_req_handler
355 .as_mut()
356 .ok_or(Error::ProtocolFeatureNotNegoiated(
357 VhostUserProtocolFeatures::SLAVE_REQ,
358 ))?;
359
360 // The virtio framework will only call this if get_shared_memory_region returned a region
361 let shmid = self
362 .shmem_region
363 .clone()
364 .flatten()
365 .expect("missing shmid")
366 .id;
367
368 backend_req_handler
369 .backend()
370 .lock()
371 .unwrap()
372 .set_shared_mapper_state(SharedMapperState { mapper, shmid });
373 Ok(())
374 }
375 }
376
/// Shared memory mapper paired with the id of the shared memory region it serves.
struct SharedMapperState {
    // Mapper that backend map/unmap requests are forwarded to.
    mapper: Box<dyn SharedMemoryMapper>,
    // Region id that incoming requests must carry; requests with another id are rejected.
    shmid: u8,
}
381
/// State used to serve requests initiated by the vhost-user backend.
///
/// Both fields start out as `None` and are filled in later through `set_interrupt` and
/// `set_shared_mapper_state`.
pub struct BackendReqHandlerImpl {
    // Interrupt signaled on config change requests; such requests fail while this is `None`.
    interrupt: Option<Interrupt>,
    // Mapper state for shared memory requests; such requests fail while this is `None`.
    shared_mapper_state: Option<SharedMapperState>,
}
386
387 impl BackendReqHandlerImpl {
set_interrupt(&mut self, interrupt: Interrupt)388 fn set_interrupt(&mut self, interrupt: Interrupt) {
389 self.interrupt = Some(interrupt);
390 }
391
set_shared_mapper_state(&mut self, shared_mapper_state: SharedMapperState)392 fn set_shared_mapper_state(&mut self, shared_mapper_state: SharedMapperState) {
393 self.shared_mapper_state = Some(shared_mapper_state);
394 }
395 }
396
397 impl VhostUserMasterReqHandlerMut for BackendReqHandlerImpl {
shmem_map( &mut self, req: &VhostUserShmemMapMsg, fd: &dyn AsRawDescriptor, ) -> HandlerResult<u64>398 fn shmem_map(
399 &mut self,
400 req: &VhostUserShmemMapMsg,
401 fd: &dyn AsRawDescriptor,
402 ) -> HandlerResult<u64> {
403 let shared_mapper_state = self
404 .shared_mapper_state
405 .as_mut()
406 .ok_or_else(|| std::io::Error::from_raw_os_error(libc::EINVAL))?;
407 if req.shmid != shared_mapper_state.shmid {
408 error!(
409 "bad shmid {}, expected {}",
410 req.shmid, shared_mapper_state.shmid
411 );
412 return Err(std::io::Error::from_raw_os_error(libc::EINVAL));
413 }
414 match shared_mapper_state.mapper.add_mapping(
415 VmMemorySource::Descriptor {
416 descriptor: SafeDescriptor::try_from(fd)
417 .map_err(|_| std::io::Error::from_raw_os_error(libc::EIO))?,
418 offset: req.fd_offset,
419 size: req.len,
420 },
421 req.shm_offset,
422 Protection::from(req.flags),
423 ) {
424 Ok(()) => Ok(0),
425 Err(e) => {
426 error!("failed to create mapping {:?}", e);
427 Err(std::io::Error::from_raw_os_error(libc::EINVAL))
428 }
429 }
430 }
431
shmem_unmap(&mut self, req: &VhostUserShmemUnmapMsg) -> HandlerResult<u64>432 fn shmem_unmap(&mut self, req: &VhostUserShmemUnmapMsg) -> HandlerResult<u64> {
433 let shared_mapper_state = self
434 .shared_mapper_state
435 .as_mut()
436 .ok_or_else(|| std::io::Error::from_raw_os_error(libc::EINVAL))?;
437 if req.shmid != shared_mapper_state.shmid {
438 error!(
439 "bad shmid {}, expected {}",
440 req.shmid, shared_mapper_state.shmid
441 );
442 return Err(std::io::Error::from_raw_os_error(libc::EINVAL));
443 }
444 match shared_mapper_state.mapper.remove_mapping(req.shm_offset) {
445 Ok(()) => Ok(0),
446 Err(e) => {
447 error!("failed to remove mapping {:?}", e);
448 Err(std::io::Error::from_raw_os_error(libc::EINVAL))
449 }
450 }
451 }
452
gpu_map( &mut self, req: &VhostUserGpuMapMsg, descriptor: &dyn AsRawDescriptor, ) -> HandlerResult<u64>453 fn gpu_map(
454 &mut self,
455 req: &VhostUserGpuMapMsg,
456 descriptor: &dyn AsRawDescriptor,
457 ) -> HandlerResult<u64> {
458 let shared_mapper_state = self
459 .shared_mapper_state
460 .as_mut()
461 .ok_or_else(|| std::io::Error::from_raw_os_error(libc::EINVAL))?;
462 if req.shmid != shared_mapper_state.shmid {
463 error!(
464 "bad shmid {}, expected {}",
465 req.shmid, shared_mapper_state.shmid
466 );
467 return Err(std::io::Error::from_raw_os_error(libc::EINVAL));
468 }
469 match shared_mapper_state.mapper.add_mapping(
470 VmMemorySource::Vulkan {
471 descriptor: SafeDescriptor::try_from(descriptor)
472 .map_err(|_| std::io::Error::from_raw_os_error(libc::EIO))?,
473 handle_type: req.handle_type,
474 memory_idx: req.memory_idx,
475 device_id: DeviceId {
476 device_uuid: req.device_uuid,
477 driver_uuid: req.driver_uuid,
478 },
479 size: req.len,
480 },
481 req.shm_offset,
482 Protection::read_write(),
483 ) {
484 Ok(()) => Ok(0),
485 Err(e) => {
486 error!("failed to create mapping {:?}", e);
487 Err(std::io::Error::from_raw_os_error(libc::EINVAL))
488 }
489 }
490 }
491
handle_config_change(&mut self) -> HandlerResult<u64>492 fn handle_config_change(&mut self) -> HandlerResult<u64> {
493 info!("Handle Config Change called");
494 match &self.interrupt {
495 Some(interrupt) => {
496 interrupt.signal_config_changed();
497 Ok(0)
498 }
499 None => {
500 error!("cannot send interrupt");
501 Err(std::io::Error::from_raw_os_error(libc::ENOSYS))
502 }
503 }
504 }
505 }
506