// Copyright 2021 The Chromium OS Authors. All rights reserved.
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.

use std::cell::RefCell;
use std::collections::BTreeMap;
use std::fs::File;
use std::io::{self, Write};
use std::mem::size_of;
use std::ops::RangeInclusive;
use std::rc::Rc;
use std::sync::Arc;
use std::{result, thread};

use acpi_tables::sdt::SDT;
use anyhow::Context;
use base::{
    error, pagesize, warn, AsRawDescriptor, Error as SysError, Event, RawDescriptor,
    Result as SysResult, Tube, TubeError,
};
use cros_async::{AsyncError, AsyncTube, EventAsync, Executor};
use data_model::{DataInit, Le64};
use futures::{select, FutureExt};
use remain::sorted;
use sync::Mutex;
use thiserror::Error;
use vm_control::{
    VirtioIOMMURequest, VirtioIOMMUResponse, VirtioIOMMUVfioCommand, VirtioIOMMUVfioResult,
};
use vm_memory::{GuestAddress, GuestMemory, GuestMemoryError};

use crate::pci::PciAddress;
use crate::virtio::{
    async_utils, copy_config, DescriptorChain, DescriptorError, Interrupt, Queue, Reader,
    SignalableInterrupt, VirtioDevice, Writer, TYPE_IOMMU,
};
use crate::VfioContainer;

pub mod protocol;
use crate::virtio::iommu::protocol::*;
pub mod ipc_memory_mapper;
use crate::virtio::iommu::ipc_memory_mapper::*;
pub mod memory_mapper;
pub mod memory_util;
pub mod vfio_wrapper;
use crate::virtio::iommu::memory_mapper::{Error as MemoryMapperError, *};

use self::vfio_wrapper::VfioWrapper;

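// The virtio-iommu spec defines two virtqueues (a request queue and an event
// queue); this device exposes both but currently only services the request
// queue (see `Worker::run`).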
const QUEUE_SIZE: u16 = 256;
const NUM_QUEUES: usize = 2;
const QUEUE_SIZES: &[u16] = &[QUEUE_SIZE; NUM_QUEUES];

// Size of the virtio_iommu_probe_resv_mem property reported by a PROBE request
// (it begins with a virtio_iommu_probe_property header).
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
const IOMMU_PROBE_SIZE: usize = size_of::<virtio_iommu_probe_resv_mem>();

const VIRTIO_IOMMU_VIOT_NODE_PCI_RANGE: u8 = 1;
const VIRTIO_IOMMU_VIOT_NODE_VIRTIO_IOMMU_PCI: u8 = 3;

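// Packed structs mirroring the ACPI VIOT table header and node layouts
// emitted by `generate_acpi()` below.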
#[derive(Copy, Clone, Debug, Default)]
#[repr(C, packed)]
struct VirtioIommuViotHeader {
    node_count: u16,
    node_offset: u16,
    reserved: [u8; 8],
}

// Safe because it only has data and has no implicit padding.
unsafe impl DataInit for VirtioIommuViotHeader {}

#[derive(Copy, Clone, Debug, Default)]
#[repr(C, packed)]
struct VirtioIommuViotVirtioPciNode {
    type_: u8,
    reserved: [u8; 1],
    length: u16,
    segment: u16,
    bdf: u16,
    reserved2: [u8; 8],
}

// Safe because it only has data and has no implicit padding.
unsafe impl DataInit for VirtioIommuViotVirtioPciNode {}

#[derive(Copy, Clone, Debug, Default)]
#[repr(C, packed)]
struct VirtioIommuViotPciRangeNode {
    type_: u8,
    reserved: [u8; 1],
    length: u16,
    endpoint_start: u32,
    segment_start: u16,
    segment_end: u16,
    bdf_start: u16,
    bdf_end: u16,
    output_node: u16,
    reserved2: [u8; 2],
    reserved3: [u8; 4],
}

// Safe because it only has data and has no implicit padding.
unsafe impl DataInit for VirtioIommuViotPciRangeNode {}

type Result<T> = result::Result<T, IommuError>;

#[sorted]
#[derive(Error, Debug)]
pub enum IommuError {
    #[error("async executor error: {0}")]
    AsyncExec(AsyncError),
    #[error("failed to create reader: {0}")]
    CreateReader(DescriptorError),
    #[error("failed to create wait context: {0}")]
    CreateWaitContext(SysError),
    #[error("failed to create writer: {0}")]
    CreateWriter(DescriptorError),
    #[error("failed getting host address: {0}")]
    GetHostAddress(GuestMemoryError),
    #[error("failed to read from guest address: {0}")]
    GuestMemoryRead(io::Error),
    #[error("failed to write to guest address: {0}")]
    GuestMemoryWrite(io::Error),
    #[error("memory mapper failed: {0}")]
    MemoryMapper(MemoryMapperError),
    #[error("Failed to read descriptor asynchronously: {0}")]
    ReadAsyncDesc(AsyncError),
    #[error("failed to read from virtio queue Event: {0}")]
    ReadQueueEvent(SysError),
    #[error("tube error: {0}")]
    Tube(TubeError),
    #[error("unexpected descriptor error")]
    UnexpectedDescriptor,
    #[error("failed to receive virtio-iommu control request: {0}")]
    VirtioIOMMUReqError(TubeError),
    #[error("failed to send virtio-iommu control response: {0}")]
    VirtioIOMMUResponseError(TubeError),
    #[error("failed to wait for events: {0}")]
    WaitError(SysError),
    #[error("write buffer length too small")]
    WriteBufferTooSmall,
}

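// Owns the per-device IOMMU state and runs on the dedicated worker thread
// spawned in `Iommu::activate()`.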
struct Worker {
    mem: GuestMemory,
    page_mask: u64,
    // Hot-pluggable PCI endpoint ranges
    // RangeInclusive: (start endpoint PCI address ..= end endpoint PCI address)
    hp_endpoints_ranges: Vec<RangeInclusive<u32>>,
    // All PCI endpoints that are attached to an IOMMU domain
    // key: endpoint PCI address
    // value: attached domain ID
    endpoint_map: BTreeMap<u32, u32>,
    // All attached domains
    // key: domain ID
    // value: reference counter and MemoryMapperTrait
    domain_map: BTreeMap<u32, (u32, Arc<Mutex<Box<dyn MemoryMapperTrait>>>)>,
}

impl Worker {
    // Remove the endpoint from endpoint_map and decrement the reference counter
    // of its domain in domain_map (or remove the domain entry if the ref count is 1).
    fn detach_endpoint(&mut self, endpoint: u32) {
        // The endpoint has attached to an IOMMU domain
        if let Some(attached_domain) = self.endpoint_map.get(&endpoint) {
            // Remove the entry or update the domain reference count
            if let Some(dm_val) = self.domain_map.get(attached_domain) {
                match dm_val.0 {
                    0 => unreachable!(),
                    1 => self.domain_map.remove(attached_domain),
                    _ => {
                        let new_refs = dm_val.0 - 1;
                        let vfio = dm_val.1.clone();
                        self.domain_map.insert(*attached_domain, (new_refs, vfio))
                    }
                };
            }
        }

        self.endpoint_map.remove(&endpoint);
    }

    // Note: if a VFIO group contains multiple devices, it could violate the following
    // requirement from the virtio IOMMU spec: If the VIRTIO_IOMMU_F_BYPASS feature
    // is negotiated, all accesses from unattached endpoints are allowed and translated
    // by the IOMMU using the identity function. If the feature is not negotiated, any
    // memory access from an unattached endpoint fails.
    //
    // This is because once the virtio-iommu device receives a VIRTIO_IOMMU_T_ATTACH
    // request for the first endpoint in a VFIO group, any endpoints in that group that
    // have not yet been attached are also able to access the domain.
    //
    // This violation is benign for current virtualization use cases. Since device
    // topology in the guest matches topology in the host, the guest doesn't expect
    // devices in the same VFIO group to be isolated from each other in the first place.
    fn process_attach_request(
        &mut self,
        reader: &mut Reader,
        tail: &mut virtio_iommu_req_tail,
        endpoints: &Rc<RefCell<BTreeMap<u32, Arc<Mutex<Box<dyn MemoryMapperTrait>>>>>>,
    ) -> Result<usize> {
        let req: virtio_iommu_req_attach =
            reader.read_obj().map_err(IommuError::GuestMemoryRead)?;

        // If the reserved field of an ATTACH request is not zero,
        // the device MUST reject the request and set status to
        // VIRTIO_IOMMU_S_INVAL.
        if req.reserved.iter().any(|&x| x != 0) {
            tail.status = VIRTIO_IOMMU_S_INVAL;
            return Ok(0);
        }

        // If the endpoint identified by endpoint doesn’t exist,
        // the device MUST reject the request and set status to
        // VIRTIO_IOMMU_S_NOENT.
        let domain: u32 = req.domain.into();
        let endpoint: u32 = req.endpoint.into();
        if !endpoints.borrow().contains_key(&endpoint) {
            tail.status = VIRTIO_IOMMU_S_NOENT;
            return Ok(0);
        }

        // If the endpoint identified by endpoint is already attached
        // to another domain, then the device SHOULD first detach it
        // from that domain and attach it to the one identified by domain.
        if self.endpoint_map.contains_key(&endpoint) {
            self.detach_endpoint(endpoint);
        }

        if let Some(vfio_container) = endpoints.borrow_mut().get(&endpoint) {
            let new_ref = match self.domain_map.get(&domain) {
                None => 1,
                Some(val) => val.0 + 1,
            };

            self.endpoint_map.insert(endpoint, domain);
            self.domain_map
                .insert(domain, (new_ref, vfio_container.clone()));
        }

        Ok(0)
    }

    fn process_dma_map_request(
        &mut self,
        reader: &mut Reader,
        tail: &mut virtio_iommu_req_tail,
    ) -> Result<usize> {
        let req: virtio_iommu_req_map = reader.read_obj().map_err(IommuError::GuestMemoryRead)?;

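        // page_mask is the IOMMU page granularity minus one (computed in
        // `Iommu::activate()`), so a non-zero AND with an address or boundary
        // means it is not aligned to that granularity.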
        // If virt_start, phys_start or (virt_end + 1) is not aligned
        // on the page granularity, the device SHOULD reject the
        // request and set status to VIRTIO_IOMMU_S_RANGE
        if self.page_mask & u64::from(req.phys_start) != 0
            || self.page_mask & u64::from(req.virt_start) != 0
            || self.page_mask & (u64::from(req.virt_end) + 1) != 0
        {
            tail.status = VIRTIO_IOMMU_S_RANGE;
            return Ok(0);
        }

        // If the device doesn’t recognize a flags bit, it MUST reject
        // the request and set status to VIRTIO_IOMMU_S_INVAL.
        if u32::from(req.flags) & !VIRTIO_IOMMU_MAP_F_MASK != 0 {
            tail.status = VIRTIO_IOMMU_S_INVAL;
            return Ok(0);
        }

        let domain: u32 = req.domain.into();
        if !self.domain_map.contains_key(&domain) {
            // If domain does not exist, the device SHOULD reject
            // the request and set status to VIRTIO_IOMMU_S_NOENT.
            tail.status = VIRTIO_IOMMU_S_NOENT;
            return Ok(0);
        }

        // The device MUST NOT allow writes to a range mapped
        // without the VIRTIO_IOMMU_MAP_F_WRITE flag.
        let write_en = u32::from(req.flags) & VIRTIO_IOMMU_MAP_F_WRITE != 0;

        if let Some(mapper) = self.domain_map.get(&domain) {
            let size = u64::from(req.virt_end) - u64::from(req.virt_start) + 1u64;

            let vfio_map_result = mapper.1.lock().add_map(MappingInfo {
                iova: req.virt_start.into(),
                gpa: GuestAddress(req.phys_start.into()),
                size,
                perm: match write_en {
                    true => Permission::RW,
                    false => Permission::Read,
                },
            });

            match vfio_map_result {
                Ok(()) => (),
                Err(e) => match e {
                    MemoryMapperError::IovaRegionOverlap => {
                        // If a mapping already exists in the requested range,
                        // the device SHOULD reject the request and set status
                        // to VIRTIO_IOMMU_S_INVAL.
                        tail.status = VIRTIO_IOMMU_S_INVAL;
                        return Ok(0);
                    }
                    _ => return Err(IommuError::MemoryMapper(e)),
                },
            }
        }

        Ok(0)
    }

    fn process_dma_unmap_request(
        &mut self,
        reader: &mut Reader,
        tail: &mut virtio_iommu_req_tail,
    ) -> Result<usize> {
        let req: virtio_iommu_req_unmap = reader.read_obj().map_err(IommuError::GuestMemoryRead)?;

        let domain: u32 = req.domain.into();
        if let Some(mapper) = self.domain_map.get(&domain) {
            let size = u64::from(req.virt_end) - u64::from(req.virt_start) + 1;
            mapper
                .1
                .lock()
                .remove_map(u64::from(req.virt_start), size)
                .map_err(IommuError::MemoryMapper)?;
        } else {
            // If domain does not exist, the device SHOULD set the
            // request status to VIRTIO_IOMMU_S_NOENT
            tail.status = VIRTIO_IOMMU_S_NOENT;
        }

        Ok(0)
    }

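    // On x86, the PROBE reply advertises a single reserved-memory property
    // covering the MSI doorbell range (0xfee0_0000..=0xfeef_ffff), so the guest
    // does not place DMA mappings over it.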
    #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
    fn process_probe_request(
        &mut self,
        reader: &mut Reader,
        writer: &mut Writer,
        tail: &mut virtio_iommu_req_tail,
        endpoints: &Rc<RefCell<BTreeMap<u32, Arc<Mutex<Box<dyn MemoryMapperTrait>>>>>>,
    ) -> Result<usize> {
        let req: virtio_iommu_req_probe = reader.read_obj().map_err(IommuError::GuestMemoryRead)?;
        let endpoint: u32 = req.endpoint.into();

        // If the endpoint identified by endpoint doesn’t exist,
        // then the device SHOULD reject the request and set status
        // to VIRTIO_IOMMU_S_NOENT.
        if !endpoints.borrow().contains_key(&endpoint) {
            tail.status = VIRTIO_IOMMU_S_NOENT;
        }

        let properties_size = writer.available_bytes() - size_of::<virtio_iommu_req_tail>();

        // It's OK if properties_size is larger than probe_size.
        // We are also fine if properties_size is 0: the check below rejects the request.
        if properties_size < IOMMU_PROBE_SIZE {
            // If the properties list is smaller than probe_size, the device
            // SHOULD NOT write any property. It SHOULD reject the request
            // and set status to VIRTIO_IOMMU_S_INVAL.
            tail.status = VIRTIO_IOMMU_S_INVAL;
        } else if tail.status == VIRTIO_IOMMU_S_OK {
            const VIRTIO_IOMMU_PROBE_T_RESV_MEM: u16 = 1;
            const VIRTIO_IOMMU_RESV_MEM_T_MSI: u8 = 1;
            const PROBE_PROPERTY_SIZE: u16 = 4;
            const X86_MSI_IOVA_START: u64 = 0xfee0_0000;
            const X86_MSI_IOVA_END: u64 = 0xfeef_ffff;

            let properties = virtio_iommu_probe_resv_mem {
                head: virtio_iommu_probe_property {
                    type_: VIRTIO_IOMMU_PROBE_T_RESV_MEM.into(),
                    length: (IOMMU_PROBE_SIZE as u16 - PROBE_PROPERTY_SIZE).into(),
                },
                subtype: VIRTIO_IOMMU_RESV_MEM_T_MSI,
                start: X86_MSI_IOVA_START.into(),
                end: X86_MSI_IOVA_END.into(),
                ..Default::default()
            };
            writer
                .write_all(properties.as_slice())
                .map_err(IommuError::GuestMemoryWrite)?;
        }

        // If the device doesn’t fill all probe_size bytes with properties,
        // it SHOULD fill the remaining bytes of properties with zeroes.
        let remaining_bytes = writer.available_bytes() - size_of::<virtio_iommu_req_tail>();

        if remaining_bytes > 0 {
            let buffer: Vec<u8> = vec![0; remaining_bytes];
            writer
                .write_all(buffer.as_slice())
                .map_err(IommuError::GuestMemoryWrite)?;
        }

        Ok(properties_size)
    }

    fn execute_request(
        &mut self,
        avail_desc: &DescriptorChain,
        endpoints: &Rc<RefCell<BTreeMap<u32, Arc<Mutex<Box<dyn MemoryMapperTrait>>>>>>,
    ) -> Result<usize> {
        let mut reader =
            Reader::new(self.mem.clone(), avail_desc.clone()).map_err(IommuError::CreateReader)?;
        let mut writer =
            Writer::new(self.mem.clone(), avail_desc.clone()).map_err(IommuError::CreateWriter)?;

        // At the very least we need space to write a virtio_iommu_req_tail.
        if writer.available_bytes() < size_of::<virtio_iommu_req_tail>() {
            return Err(IommuError::WriteBufferTooSmall);
        }

        let req_head: virtio_iommu_req_head =
            reader.read_obj().map_err(IommuError::GuestMemoryRead)?;

        let mut tail = virtio_iommu_req_tail {
            status: VIRTIO_IOMMU_S_OK,
            ..Default::default()
        };

        let reply_len = match req_head.type_ {
            VIRTIO_IOMMU_T_ATTACH => {
                self.process_attach_request(&mut reader, &mut tail, endpoints)?
            }
            VIRTIO_IOMMU_T_DETACH => {
                // A few reasons why we don't support VIRTIO_IOMMU_T_DETACH for now:
                //
                // 1. The Linux virtio IOMMU front-end driver doesn't implement the
                //    VIRTIO_IOMMU_T_DETACH request.
                // 2. It doesn't seem possible to dynamically attach and detach an IOMMU
                //    domain if the virtio IOMMU device is running on top of VFIO.
                // 3. Even if VIRTIO_IOMMU_T_DETACH were implemented in the front-end
                //    driver, it could violate the following virtio IOMMU spec: Detach an
                //    endpoint from a domain. When this request completes, the endpoint
                //    cannot access any mapping from that domain anymore.
                //
                // This is because VFIO doesn't support detaching a single device. When the
                // virtio-iommu device receives a VIRTIO_IOMMU_T_DETACH request, it can either:
                // - detach the whole group: any other endpoints in the group lose access to
                //   the domain.
                // - not detach the group at all: this breaks the above mentioned spec.
                tail.status = VIRTIO_IOMMU_S_UNSUPP;
                0
            }
            VIRTIO_IOMMU_T_MAP => self.process_dma_map_request(&mut reader, &mut tail)?,
            VIRTIO_IOMMU_T_UNMAP => self.process_dma_unmap_request(&mut reader, &mut tail)?,
            #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
            VIRTIO_IOMMU_T_PROBE => {
                self.process_probe_request(&mut reader, &mut writer, &mut tail, endpoints)?
            }
            _ => return Err(IommuError::UnexpectedDescriptor),
        };

        writer
            .write_all(tail.as_slice())
            .map_err(IommuError::GuestMemoryWrite)?;
        Ok((reply_len as usize) + size_of::<virtio_iommu_req_tail>())
    }

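    // Services the request virtqueue: pops each descriptor chain, executes the
    // request, and returns the used buffer to the guest with an interrupt.
    // Loops forever; only returns if reading from the queue fails.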
    async fn request_queue<I: SignalableInterrupt>(
        &mut self,
        mut queue: Queue,
        mut queue_event: EventAsync,
        interrupt: &I,
        endpoints: &Rc<RefCell<BTreeMap<u32, Arc<Mutex<Box<dyn MemoryMapperTrait>>>>>>,
    ) -> Result<()> {
        loop {
            let avail_desc = queue
                .next_async(&self.mem, &mut queue_event)
                .await
                .map_err(IommuError::ReadAsyncDesc)?;
            let desc_index = avail_desc.index;

            let len = match self.execute_request(&avail_desc, endpoints) {
                Ok(len) => len as u32,
                Err(e) => {
                    error!("execute_request failed: {}", e);

                    // If a request type is not recognized, the device SHOULD NOT write
                    // the buffer and SHOULD set the used length to zero
                    0
                }
            };

            queue.add_used(&self.mem, desc_index, len as u32);
            queue.trigger_interrupt(&self.mem, interrupt);
        }
    }

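    // Registers a hot-plugged VFIO endpoint: checks that its PCI address falls
    // within one of the configured hot-pluggable ranges, then wraps the supplied
    // VFIO container in a VfioWrapper and adds it to the endpoints map.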
    fn handle_add_vfio_device(
        mem: &GuestMemory,
        endpoint_addr: u32,
        container_fd: File,
        endpoints: &Rc<RefCell<BTreeMap<u32, Arc<Mutex<Box<dyn MemoryMapperTrait>>>>>>,
        hp_endpoints_ranges: &Rc<Vec<RangeInclusive<u32>>>,
    ) -> VirtioIOMMUVfioResult {
        let exists = |endpoint_addr: u32| -> bool {
            for endpoints_range in hp_endpoints_ranges.iter() {
                if endpoints_range.contains(&endpoint_addr) {
                    return true;
                }
            }
            false
        };

        if !exists(endpoint_addr) {
            return VirtioIOMMUVfioResult::NotInPCIRanges;
        }

        let vfio_container = match VfioContainer::new_from_container(container_fd) {
            Ok(vfio_container) => vfio_container,
            Err(e) => {
                error!("failed to verify the new container: {}", e);
                return VirtioIOMMUVfioResult::NoAvailableContainer;
            }
        };
        endpoints.borrow_mut().insert(
            endpoint_addr,
            Arc::new(Mutex::new(Box::new(VfioWrapper::new(
                Arc::new(Mutex::new(vfio_container)),
                mem.clone(),
            )))),
        );
        VirtioIOMMUVfioResult::Ok
    }

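    // Removes a hot-unplugged endpoint (and its VFIO container wrapper) from the
    // endpoints map.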
    fn handle_del_vfio_device(
        pci_address: u32,
        endpoints: &Rc<RefCell<BTreeMap<u32, Arc<Mutex<Box<dyn MemoryMapperTrait>>>>>>,
    ) -> VirtioIOMMUVfioResult {
        if endpoints.borrow_mut().remove(&pci_address).is_none() {
            error!("there is no vfio container for endpoint {}", pci_address);
            return VirtioIOMMUVfioResult::NoSuchDevice;
        }
        VirtioIOMMUVfioResult::Ok
    }

    fn handle_vfio(
        mem: &GuestMemory,
        vfio_cmd: VirtioIOMMUVfioCommand,
        endpoints: &Rc<RefCell<BTreeMap<u32, Arc<Mutex<Box<dyn MemoryMapperTrait>>>>>>,
        hp_endpoints_ranges: &Rc<Vec<RangeInclusive<u32>>>,
    ) -> VirtioIOMMUResponse {
        use VirtioIOMMUVfioCommand::*;
        let vfio_result = match vfio_cmd {
            VfioDeviceAdd {
                endpoint_addr,
                container,
            } => Self::handle_add_vfio_device(
                mem,
                endpoint_addr,
                container,
                endpoints,
                hp_endpoints_ranges,
            ),
            VfioDeviceDel { endpoint_addr } => {
                Self::handle_del_vfio_device(endpoint_addr, endpoints)
            }
        };
        VirtioIOMMUResponse::VfioResponse(vfio_result)
    }

    // Async task that handles messages from the host
    async fn handle_command_tube(
        mem: &GuestMemory,
        command_tube: AsyncTube,
        endpoints: &Rc<RefCell<BTreeMap<u32, Arc<Mutex<Box<dyn MemoryMapperTrait>>>>>>,
        hp_endpoints_ranges: &Rc<Vec<RangeInclusive<u32>>>,
    ) -> Result<()> {
        loop {
            match command_tube.next::<VirtioIOMMURequest>().await {
                Ok(command) => {
                    let response: VirtioIOMMUResponse = match command {
                        VirtioIOMMURequest::VfioCommand(vfio_cmd) => {
                            Self::handle_vfio(mem, vfio_cmd, endpoints, hp_endpoints_ranges)
                        }
                    };
                    if let Err(e) = command_tube.send(response).await {
                        error!("{}", IommuError::VirtioIOMMUResponseError(e));
                    }
                }
                Err(e) => {
                    return Err(IommuError::VirtioIOMMUReqError(e));
                }
            }
        }
    }

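    // Sets up a single-threaded executor and runs the worker's event loop: the
    // request queue handler, IRQ resample, kill-event wait, translate-request
    // forwarding, and the host command tube are raced with select! until one of
    // them exits or fails.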
    fn run(
        &mut self,
        iommu_device_tube: Tube,
        mut queues: Vec<Queue>,
        queue_evts: Vec<Event>,
        kill_evt: Event,
        interrupt: Interrupt,
        endpoints: BTreeMap<u32, Arc<Mutex<Box<dyn MemoryMapperTrait>>>>,
        translate_response_senders: Option<BTreeMap<u32, Tube>>,
        translate_request_rx: Option<Tube>,
    ) -> Result<()> {
        let ex = Executor::new().expect("Failed to create an executor");

        let mut evts_async: Vec<EventAsync> = queue_evts
            .into_iter()
            .map(|e| EventAsync::new(e.0, &ex).expect("Failed to create async event for queue"))
            .collect();
        let interrupt = Rc::new(RefCell::new(interrupt));
        let interrupt_ref = &*interrupt.borrow();

        let (req_queue, req_evt) = (queues.remove(0), evts_async.remove(0));

        let hp_endpoints_ranges = Rc::new(self.hp_endpoints_ranges.clone());
        let mem = Rc::new(self.mem.clone());
        // contains all pass-through endpoints that attach to this IOMMU device
        // key: endpoint PCI address
        // value: MemoryMapperTrait for that endpoint
        let endpoints: Rc<RefCell<BTreeMap<u32, Arc<Mutex<Box<dyn MemoryMapperTrait>>>>>> =
            Rc::new(RefCell::new(endpoints));

        let f_resample = async_utils::handle_irq_resample(&ex, interrupt.clone());
        let f_kill = async_utils::await_and_exit(&ex, kill_evt);

        let request_tube = translate_request_rx
            .map(|t| AsyncTube::new(&ex, t).expect("Failed to create async tube for rx"));
        let response_tubes = translate_response_senders.map(|m| {
            m.into_iter()
                .map(|x| {
                    (
                        x.0,
                        AsyncTube::new(&ex, x.1).expect("Failed to create async tube"),
                    )
                })
                .collect()
        });

        let f_handle_translate_request =
            handle_translate_request(&endpoints, request_tube, response_tubes);
        let f_request = self.request_queue(req_queue, req_evt, interrupt_ref, &endpoints);

        let command_tube = AsyncTube::new(&ex, iommu_device_tube).unwrap();
        // Future to handle command messages from host, such as passing vfio containers.
        let f_cmd = Self::handle_command_tube(&mem, command_tube, &endpoints, &hp_endpoints_ranges);

        let done = async {
            select! {
                res = f_request.fuse() => res.context("error in handling request queue"),
                res = f_resample.fuse() => res.context("error in handle_irq_resample"),
                res = f_kill.fuse() => res.context("error in await_and_exit"),
                res = f_handle_translate_request.fuse() => res.context("error in handle_translate_request"),
                res = f_cmd.fuse() => res.context("error in handling host request"),
            }
        };
        match ex.run_until(done) {
            Ok(Ok(())) => {}
            Ok(Err(e)) => error!("Error in worker: {}", e),
            Err(e) => return Err(IommuError::AsyncExec(e)),
        }

        Ok(())
    }
}

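// Forwards IOVA-to-GPA translation requests from other devices to the memory
// mapper of the target endpoint, replying on that endpoint's response tube.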
async fn handle_translate_request(
    endpoints: &Rc<RefCell<BTreeMap<u32, Arc<Mutex<Box<dyn MemoryMapperTrait>>>>>>,
    request_tube: Option<AsyncTube>,
    response_tubes: Option<BTreeMap<u32, AsyncTube>>,
) -> Result<()> {
    let request_tube = match request_tube {
        Some(r) => r,
        None => {
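            // No translate request tube was provided, so park this future forever;
            // the select! in Worker::run() still completes via its other arms.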
            let () = futures::future::pending().await;
            return Ok(());
        }
    };
    let response_tubes = response_tubes.unwrap();
    loop {
        let TranslateRequest {
            endpoint_id,
            iova,
            size,
        } = request_tube.next().await.map_err(IommuError::Tube)?;
        if let Some(mapper) = endpoints.borrow_mut().get(&endpoint_id) {
            response_tubes
                .get(&endpoint_id)
                .unwrap()
                .send(
                    mapper
                        .lock()
                        .translate(iova, size)
                        .map_err(|e| {
                            error!("Failed to handle TranslateRequest: {}", e);
                            e
                        })
                        .ok(),
                )
                .await
                .map_err(IommuError::Tube)?;
        } else {
            error!("endpoint_id {} not found", endpoint_id)
        }
    }
}

/// Virtio device for IOMMU memory management.
pub struct Iommu {
    kill_evt: Option<Event>,
    worker_thread: Option<thread::JoinHandle<Worker>>,
    config: virtio_iommu_config,
    avail_features: u64,
    // Attached endpoints
    // key: endpoint PCI address
    // value: MemoryMapperTrait for the endpoint
    endpoints: BTreeMap<u32, Arc<Mutex<Box<dyn MemoryMapperTrait>>>>,
    // Hot-pluggable PCI endpoint ranges
    // RangeInclusive: (start endpoint PCI address ..= end endpoint PCI address)
    hp_endpoints_ranges: Vec<RangeInclusive<u32>>,
    translate_response_senders: Option<BTreeMap<u32, Tube>>,
    translate_request_rx: Option<Tube>,
    iommu_device_tube: Option<Tube>,
}

impl Iommu {
    /// Create a new virtio IOMMU device.
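    ///
    /// A minimal construction sketch (not a doctest; `base_features` and
    /// `iommu_device_tube` are assumed to come from the transport and the VMM
    /// respectively, and the 64 GiB address ceiling is only an example):
    ///
    /// ```ignore
    /// let iommu = Iommu::new(
    ///     base_features,           // virtio feature bits from the transport
    ///     BTreeMap::new(),         // no pass-through endpoints attached at boot
    ///     (1u64 << 36) - 1,        // highest guest-physical address to accept
    ///     Vec::new(),              // no hot-pluggable endpoint ranges
    ///     None,                    // no translate-response senders
    ///     None,                    // no translate-request tube
    ///     Some(iommu_device_tube), // control tube used for VFIO hot-plug commands
    /// )?;
    /// ```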
    pub fn new(
        base_features: u64,
        endpoints: BTreeMap<u32, Arc<Mutex<Box<dyn MemoryMapperTrait>>>>,
        phys_max_addr: u64,
        hp_endpoints_ranges: Vec<RangeInclusive<u32>>,
        translate_response_senders: Option<BTreeMap<u32, Tube>>,
        translate_request_rx: Option<Tube>,
        iommu_device_tube: Option<Tube>,
    ) -> SysResult<Iommu> {
        let mut page_size_mask = !0_u64;
        for (_, container) in endpoints.iter() {
            page_size_mask &= container
                .lock()
                .get_mask()
                .map_err(|_e| SysError::new(libc::EIO))?;
        }

        if page_size_mask == 0 {
            // If no endpoints are bound to the vIOMMU during system boot, assume
            // the IOVA page size is the same as the host page size.
            page_size_mask = (pagesize() as u64) - 1;
        }

        let input_range = virtio_iommu_range_64 {
            start: Le64::from(0),
            end: phys_max_addr.into(),
        };

        let config = virtio_iommu_config {
            page_size_mask: page_size_mask.into(),
            input_range,
            #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
            probe_size: (IOMMU_PROBE_SIZE as u32).into(),
            ..Default::default()
        };

        let mut avail_features: u64 = base_features;
        avail_features |= 1 << VIRTIO_IOMMU_F_MAP_UNMAP | 1 << VIRTIO_IOMMU_F_INPUT_RANGE;

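        // The PROBE request (and therefore VIRTIO_IOMMU_F_PROBE) is only wired up
        // on x86/x86_64, where it reports the MSI doorbell range as reserved; see
        // process_probe_request().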
        if cfg!(any(target_arch = "x86", target_arch = "x86_64")) {
            avail_features |= 1 << VIRTIO_IOMMU_F_PROBE;
        }

        Ok(Iommu {
            kill_evt: None,
            worker_thread: None,
            config,
            avail_features,
            endpoints,
            hp_endpoints_ranges,
            translate_response_senders,
            translate_request_rx,
            iommu_device_tube,
        })
    }
}

impl Drop for Iommu {
    fn drop(&mut self) {
        if let Some(kill_evt) = self.kill_evt.take() {
            let _ = kill_evt.write(1);
        }

        if let Some(worker_thread) = self.worker_thread.take() {
            let _ = worker_thread.join();
        }
    }
}

impl VirtioDevice for Iommu {
    fn keep_rds(&self) -> Vec<RawDescriptor> {
        let mut rds = Vec::new();

        for (_, mapper) in self.endpoints.iter() {
            rds.append(&mut mapper.lock().as_raw_descriptors());
        }
        if let Some(senders) = &self.translate_response_senders {
            for (_, tube) in senders.iter() {
                rds.push(tube.as_raw_descriptor());
            }
        }
        if let Some(rx) = &self.translate_request_rx {
            rds.push(rx.as_raw_descriptor());
        }

        if let Some(iommu_device_tube) = &self.iommu_device_tube {
            rds.push(iommu_device_tube.as_raw_descriptor());
        }

        rds
    }

    fn device_type(&self) -> u32 {
        TYPE_IOMMU
    }

    fn queue_max_sizes(&self) -> &[u16] {
        QUEUE_SIZES
    }

    fn features(&self) -> u64 {
        self.avail_features
    }

    fn read_config(&self, offset: u64, data: &mut [u8]) {
        let mut config: Vec<u8> = Vec::new();
        config.extend_from_slice(self.config.as_slice());
        copy_config(data, 0, config.as_slice(), offset);
    }

    fn activate(
        &mut self,
        mem: GuestMemory,
        interrupt: Interrupt,
        queues: Vec<Queue>,
        queue_evts: Vec<Event>,
    ) {
        if queues.len() != QUEUE_SIZES.len() || queue_evts.len() != QUEUE_SIZES.len() {
            return;
        }

        let (self_kill_evt, kill_evt) = match Event::new().and_then(|e| Ok((e.try_clone()?, e))) {
            Ok(v) => v,
            Err(e) => {
                error!("failed to create kill Event pair: {}", e);
                return;
            }
        };
        self.kill_evt = Some(self_kill_evt);

        // The least significant set bit of page_size_mask defines the page
        // granularity of IOMMU mappings
        let page_mask = (1u64 << u64::from(self.config.page_size_mask).trailing_zeros()) - 1;
        let eps = self.endpoints.clone();
        let hp_endpoints_ranges = self.hp_endpoints_ranges.to_owned();

        let translate_response_senders = self.translate_response_senders.take();
        let translate_request_rx = self.translate_request_rx.take();

        match self.iommu_device_tube.take() {
            Some(iommu_device_tube) => {
                let worker_result = thread::Builder::new()
                    .name("virtio_iommu".to_string())
                    .spawn(move || {
                        let mut worker = Worker {
                            mem,
                            page_mask,
                            hp_endpoints_ranges,
                            endpoint_map: BTreeMap::new(),
                            domain_map: BTreeMap::new(),
                        };
                        let result = worker.run(
                            iommu_device_tube,
                            queues,
                            queue_evts,
                            kill_evt,
                            interrupt,
                            eps,
                            translate_response_senders,
                            translate_request_rx,
                        );
                        if let Err(e) = result {
                            error!("virtio-iommu worker thread exited with error: {}", e);
                        }
                        worker
                    });

                match worker_result {
                    Err(e) => error!("failed to spawn virtio_iommu worker thread: {}", e),
                    Ok(join_handle) => self.worker_thread = Some(join_handle),
                }
            }
            None => {
                error!("failed to start virtio-iommu worker: No control tube");
                return;
            }
        }
    }

    #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
    fn generate_acpi(
        &mut self,
        pci_address: &Option<PciAddress>,
        mut sdts: Vec<SDT>,
    ) -> Option<Vec<SDT>> {
        const OEM_REVISION: u32 = 1;
        const VIOT_REVISION: u8 = 0;

        for sdt in sdts.iter() {
            // there should only be one VIOT table
            if sdt.is_signature(b"VIOT") {
                warn!("vIOMMU: duplicate VIOT table detected");
                return None;
            }
        }

        let mut viot = SDT::new(
            *b"VIOT",
            acpi_tables::HEADER_LEN,
            VIOT_REVISION,
            *b"CROSVM",
            *b"CROSVMDT",
            OEM_REVISION,
        );
        viot.append(VirtioIommuViotHeader {
            // # of PCI range nodes + 1 virtio-pci node
            node_count: (self.endpoints.len() + self.hp_endpoints_ranges.len() + 1) as u16,
            node_offset: (viot.len() + std::mem::size_of::<VirtioIommuViotHeader>()) as u16,
            ..Default::default()
        });

        let bdf = pci_address
            .or_else(|| {
                error!("vIOMMU device has no PCI address");
                None
            })?
            .to_u32() as u16;
        let iommu_offset = viot.len();

        viot.append(VirtioIommuViotVirtioPciNode {
            type_: VIRTIO_IOMMU_VIOT_NODE_VIRTIO_IOMMU_PCI,
            length: size_of::<VirtioIommuViotVirtioPciNode>() as u16,
            bdf,
            ..Default::default()
        });

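        // One PCI range node per statically attached endpoint, each covering a
        // single BDF and pointing back at the virtio-iommu node above.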
        for (endpoint, _) in self.endpoints.iter() {
            viot.append(VirtioIommuViotPciRangeNode {
                type_: VIRTIO_IOMMU_VIOT_NODE_PCI_RANGE,
                length: size_of::<VirtioIommuViotPciRangeNode>() as u16,
                endpoint_start: *endpoint,
                bdf_start: *endpoint as u16,
                bdf_end: *endpoint as u16,
                output_node: iommu_offset as u16,
                ..Default::default()
            });
        }

        for endpoints_range in self.hp_endpoints_ranges.iter() {
            let (endpoint_start, endpoint_end) = endpoints_range.clone().into_inner();
            viot.append(VirtioIommuViotPciRangeNode {
                type_: VIRTIO_IOMMU_VIOT_NODE_PCI_RANGE,
                length: size_of::<VirtioIommuViotPciRangeNode>() as u16,
                endpoint_start,
                bdf_start: endpoint_start as u16,
                bdf_end: endpoint_end as u16,
                output_node: iommu_offset as u16,
                ..Default::default()
            });
        }

        sdts.push(viot);
        Some(sdts)
    }
}