// Copyright 2021 The Chromium OS Authors. All rights reserved.
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.

use std::cell::RefCell;
use std::collections::BTreeMap;
use std::fs::File;
use std::io::{self, Write};
use std::mem::size_of;
use std::ops::RangeInclusive;
use std::rc::Rc;
use std::sync::Arc;
use std::{result, thread};

use acpi_tables::sdt::SDT;
use anyhow::Context;
use base::{
    error, pagesize, warn, AsRawDescriptor, Error as SysError, Event, RawDescriptor,
    Result as SysResult, Tube, TubeError,
};
use cros_async::{AsyncError, AsyncTube, EventAsync, Executor};
use data_model::{DataInit, Le64};
use futures::{select, FutureExt};
use remain::sorted;
use sync::Mutex;
use thiserror::Error;
use vm_control::{
    VirtioIOMMURequest, VirtioIOMMUResponse, VirtioIOMMUVfioCommand, VirtioIOMMUVfioResult,
};
use vm_memory::{GuestAddress, GuestMemory, GuestMemoryError};

use crate::pci::PciAddress;
use crate::virtio::{
    async_utils, copy_config, DescriptorChain, DescriptorError, Interrupt, Queue, Reader,
    SignalableInterrupt, VirtioDevice, Writer, TYPE_IOMMU,
};
use crate::VfioContainer;

pub mod protocol;
use crate::virtio::iommu::protocol::*;
pub mod ipc_memory_mapper;
use crate::virtio::iommu::ipc_memory_mapper::*;
pub mod memory_mapper;
pub mod memory_util;
pub mod vfio_wrapper;
use crate::virtio::iommu::memory_mapper::{Error as MemoryMapperError, *};

use self::vfio_wrapper::VfioWrapper;

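// The device exposes a request queue and an event queue (NUM_QUEUES == 2);
// only the request queue is currently serviced by the worker below.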
const QUEUE_SIZE: u16 = 256;
const NUM_QUEUES: usize = 2;
const QUEUE_SIZES: &[u16] = &[QUEUE_SIZE; NUM_QUEUES];

// Size of struct virtio_iommu_probe_property
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
const IOMMU_PROBE_SIZE: usize = size_of::<virtio_iommu_probe_resv_mem>();

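// Node types from the ACPI Virtual I/O Translation Table (VIOT) specification.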
const VIRTIO_IOMMU_VIOT_NODE_PCI_RANGE: u8 = 1;
const VIRTIO_IOMMU_VIOT_NODE_VIRTIO_IOMMU_PCI: u8 = 3;

#[derive(Copy, Clone, Debug, Default)]
#[repr(C, packed)]
struct VirtioIommuViotHeader {
    node_count: u16,
    node_offset: u16,
    reserved: [u8; 8],
}

// Safe because it only has data and has no implicit padding.
unsafe impl DataInit for VirtioIommuViotHeader {}

#[derive(Copy, Clone, Debug, Default)]
#[repr(C, packed)]
struct VirtioIommuViotVirtioPciNode {
    type_: u8,
    reserved: [u8; 1],
    length: u16,
    segment: u16,
    bdf: u16,
    reserved2: [u8; 8],
}

// Safe because it only has data and has no implicit padding.
unsafe impl DataInit for VirtioIommuViotVirtioPciNode {}

#[derive(Copy, Clone, Debug, Default)]
#[repr(C, packed)]
struct VirtioIommuViotPciRangeNode {
    type_: u8,
    reserved: [u8; 1],
    length: u16,
    endpoint_start: u32,
    segment_start: u16,
    segment_end: u16,
    bdf_start: u16,
    bdf_end: u16,
    output_node: u16,
    reserved2: [u8; 2],
    reserved3: [u8; 4],
}

// Safe because it only has data and has no implicit padding.
unsafe impl DataInit for VirtioIommuViotPciRangeNode {}

type Result<T> = result::Result<T, IommuError>;

#[sorted]
#[derive(Error, Debug)]
pub enum IommuError {
    #[error("async executor error: {0}")]
    AsyncExec(AsyncError),
    #[error("failed to create reader: {0}")]
    CreateReader(DescriptorError),
    #[error("failed to create wait context: {0}")]
    CreateWaitContext(SysError),
    #[error("failed to create writer: {0}")]
    CreateWriter(DescriptorError),
    #[error("failed getting host address: {0}")]
    GetHostAddress(GuestMemoryError),
    #[error("failed to read from guest address: {0}")]
    GuestMemoryRead(io::Error),
    #[error("failed to write to guest address: {0}")]
    GuestMemoryWrite(io::Error),
    #[error("memory mapper failed: {0}")]
    MemoryMapper(MemoryMapperError),
    #[error("failed to read descriptor asynchronously: {0}")]
    ReadAsyncDesc(AsyncError),
    #[error("failed to read from virtio queue Event: {0}")]
    ReadQueueEvent(SysError),
    #[error("tube error: {0}")]
    Tube(TubeError),
    #[error("unexpected descriptor error")]
    UnexpectedDescriptor,
    #[error("failed to receive virtio-iommu control request: {0}")]
    VirtioIOMMUReqError(TubeError),
    #[error("failed to send virtio-iommu control response: {0}")]
    VirtioIOMMUResponseError(TubeError),
    #[error("failed to wait for events: {0}")]
    WaitError(SysError),
    #[error("write buffer length too small")]
    WriteBufferTooSmall,
}

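// State for the virtio-iommu worker thread: tracks which endpoints are
// attached to which domains, and the memory mapper backing each domain.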
struct Worker {
    mem: GuestMemory,
    page_mask: u64,
    // Ranges of hot-pluggable PCI endpoints
    // RangeInclusive: (start endpoint PCI address ..= end endpoint PCI address)
    hp_endpoints_ranges: Vec<RangeInclusive<u32>>,
    // All PCI endpoints that are attached to an IOMMU domain
    // key: endpoint PCI address
    // value: attached domain ID
    endpoint_map: BTreeMap<u32, u32>,
    // All attached domains
    // key: domain ID
    // value: reference count and MemoryMapperTrait
    domain_map: BTreeMap<u32, (u32, Arc<Mutex<Box<dyn MemoryMapperTrait>>>)>,
}

impl Worker {
    // Remove the endpoint from endpoint_map and decrement its domain's
    // reference count in domain_map (removing the entry when the count
    // reaches 1).
    fn detach_endpoint(&mut self, endpoint: u32) {
        // The endpoint has attached to an IOMMU domain
        if let Some(attached_domain) = self.endpoint_map.get(&endpoint) {
            // Remove the entry or update the domain reference count
            if let Some(dm_val) = self.domain_map.get(attached_domain) {
                match dm_val.0 {
                    0 => unreachable!(),
                    1 => self.domain_map.remove(attached_domain),
                    _ => {
                        let new_refs = dm_val.0 - 1;
                        let vfio = dm_val.1.clone();
                        self.domain_map.insert(*attached_domain, (new_refs, vfio))
                    }
                };
            }
        }

        self.endpoint_map.remove(&endpoint);
    }

    // Note: if a VFIO group contains multiple devices, it could violate the
    // following requirement from the virtio IOMMU spec: If the
    // VIRTIO_IOMMU_F_BYPASS feature is negotiated, all accesses from
    // unattached endpoints are allowed and translated by the IOMMU using the
    // identity function. If the feature is not negotiated, any memory access
    // from an unattached endpoint fails.
    //
    // Specifically, after the virtio-iommu device receives a
    // VIRTIO_IOMMU_T_ATTACH request for the first endpoint in a VFIO group,
    // any not-yet-attached endpoints in that group gain access to the domain.
    //
    // This violation is benign for current virtualization use cases. Since
    // device topology in the guest matches topology in the host, the guest
    // doesn't expect devices in the same VFIO group to be isolated from each
    // other in the first place.
    fn process_attach_request(
        &mut self,
        reader: &mut Reader,
        tail: &mut virtio_iommu_req_tail,
        endpoints: &Rc<RefCell<BTreeMap<u32, Arc<Mutex<Box<dyn MemoryMapperTrait>>>>>>,
    ) -> Result<usize> {
        let req: virtio_iommu_req_attach =
            reader.read_obj().map_err(IommuError::GuestMemoryRead)?;

        // If the reserved field of an ATTACH request is not zero,
        // the device MUST reject the request and set status to
        // VIRTIO_IOMMU_S_INVAL.
        if req.reserved.iter().any(|&x| x != 0) {
            tail.status = VIRTIO_IOMMU_S_INVAL;
            return Ok(0);
        }

        // If the endpoint identified by `endpoint` doesn't exist,
        // the device MUST reject the request and set status to
        // VIRTIO_IOMMU_S_NOENT.
        let domain: u32 = req.domain.into();
        let endpoint: u32 = req.endpoint.into();
        if !endpoints.borrow().contains_key(&endpoint) {
            tail.status = VIRTIO_IOMMU_S_NOENT;
            return Ok(0);
        }

        // If the endpoint identified by `endpoint` is already attached
        // to another domain, the device SHOULD first detach it
        // from that domain and attach it to the one identified by `domain`.
        if self.endpoint_map.contains_key(&endpoint) {
            self.detach_endpoint(endpoint);
        }

        if let Some(vfio_container) = endpoints.borrow_mut().get(&endpoint) {
            let new_ref = match self.domain_map.get(&domain) {
                None => 1,
                Some(val) => val.0 + 1,
            };

            self.endpoint_map.insert(endpoint, domain);
            self.domain_map
                .insert(domain, (new_ref, vfio_container.clone()));
        }

        Ok(0)
    }

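    // Handle a VIRTIO_IOMMU_T_MAP request: validate alignment and flags, then
    // add the IOVA -> GPA mapping to the domain's memory mapper.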
    fn process_dma_map_request(
        &mut self,
        reader: &mut Reader,
        tail: &mut virtio_iommu_req_tail,
    ) -> Result<usize> {
        let req: virtio_iommu_req_map = reader.read_obj().map_err(IommuError::GuestMemoryRead)?;

        // If virt_start, phys_start or (virt_end + 1) is not aligned
        // on the page granularity, the device SHOULD reject the
        // request and set status to VIRTIO_IOMMU_S_RANGE.
        if self.page_mask & u64::from(req.phys_start) != 0
            || self.page_mask & u64::from(req.virt_start) != 0
            || self.page_mask & (u64::from(req.virt_end) + 1) != 0
        {
            tail.status = VIRTIO_IOMMU_S_RANGE;
            return Ok(0);
        }

        // If the device doesn't recognize a flags bit, it MUST reject
        // the request and set status to VIRTIO_IOMMU_S_INVAL.
        if u32::from(req.flags) & !VIRTIO_IOMMU_MAP_F_MASK != 0 {
            tail.status = VIRTIO_IOMMU_S_INVAL;
            return Ok(0);
        }

        let domain: u32 = req.domain.into();
        if !self.domain_map.contains_key(&domain) {
            // If domain does not exist, the device SHOULD reject
            // the request and set status to VIRTIO_IOMMU_S_NOENT.
            tail.status = VIRTIO_IOMMU_S_NOENT;
            return Ok(0);
        }

        // The device MUST NOT allow writes to a range mapped
        // without the VIRTIO_IOMMU_MAP_F_WRITE flag.
        let write_en = u32::from(req.flags) & VIRTIO_IOMMU_MAP_F_WRITE != 0;

        if let Some(mapper) = self.domain_map.get(&domain) {
            let size = u64::from(req.virt_end) - u64::from(req.virt_start) + 1u64;

            let vfio_map_result = mapper.1.lock().add_map(MappingInfo {
                iova: req.virt_start.into(),
                gpa: GuestAddress(req.phys_start.into()),
                size,
                perm: match write_en {
                    true => Permission::RW,
                    false => Permission::Read,
                },
            });

            match vfio_map_result {
                Ok(()) => (),
                Err(e) => match e {
                    MemoryMapperError::IovaRegionOverlap => {
                        // If a mapping already exists in the requested range,
                        // the device SHOULD reject the request and set status
                        // to VIRTIO_IOMMU_S_INVAL.
                        tail.status = VIRTIO_IOMMU_S_INVAL;
                        return Ok(0);
                    }
                    _ => return Err(IommuError::MemoryMapper(e)),
                },
            }
        }

        Ok(0)
    }

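    // Handle a VIRTIO_IOMMU_T_UNMAP request: remove the mappings covering
    // [virt_start, virt_end] from the domain's memory mapper.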
    fn process_dma_unmap_request(
        &mut self,
        reader: &mut Reader,
        tail: &mut virtio_iommu_req_tail,
    ) -> Result<usize> {
        let req: virtio_iommu_req_unmap = reader.read_obj().map_err(IommuError::GuestMemoryRead)?;

        let domain: u32 = req.domain.into();
        if let Some(mapper) = self.domain_map.get(&domain) {
            let size = u64::from(req.virt_end) - u64::from(req.virt_start) + 1;
            mapper
                .1
                .lock()
                .remove_map(u64::from(req.virt_start), size)
                .map_err(IommuError::MemoryMapper)?;
        } else {
            // If domain does not exist, the device SHOULD set the
            // request status to VIRTIO_IOMMU_S_NOENT.
            tail.status = VIRTIO_IOMMU_S_NOENT;
        }

        Ok(0)
    }

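    // Handle a VIRTIO_IOMMU_T_PROBE request: report the x86 MSI region as a
    // reserved-memory property and zero-fill the rest of the property buffer.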
    #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
    fn process_probe_request(
        &mut self,
        reader: &mut Reader,
        writer: &mut Writer,
        tail: &mut virtio_iommu_req_tail,
        endpoints: &Rc<RefCell<BTreeMap<u32, Arc<Mutex<Box<dyn MemoryMapperTrait>>>>>>,
    ) -> Result<usize> {
        let req: virtio_iommu_req_probe = reader.read_obj().map_err(IommuError::GuestMemoryRead)?;
        let endpoint: u32 = req.endpoint.into();

        // If the endpoint identified by `endpoint` doesn't exist,
        // the device SHOULD reject the request and set status
        // to VIRTIO_IOMMU_S_NOENT.
        if !endpoints.borrow().contains_key(&endpoint) {
            tail.status = VIRTIO_IOMMU_S_NOENT;
        }

        let properties_size = writer.available_bytes() - size_of::<virtio_iommu_req_tail>();

        // It's fine if properties_size is larger than IOMMU_PROBE_SIZE, and
        // also fine if it is 0.
        if properties_size < IOMMU_PROBE_SIZE {
            // If the properties list is smaller than probe_size, the device
            // SHOULD NOT write any property. It SHOULD reject the request
            // and set status to VIRTIO_IOMMU_S_INVAL.
            tail.status = VIRTIO_IOMMU_S_INVAL;
        } else if tail.status == VIRTIO_IOMMU_S_OK {
            const VIRTIO_IOMMU_PROBE_T_RESV_MEM: u16 = 1;
            const VIRTIO_IOMMU_RESV_MEM_T_MSI: u8 = 1;
            const PROBE_PROPERTY_SIZE: u16 = 4;
            const X86_MSI_IOVA_START: u64 = 0xfee0_0000;
            const X86_MSI_IOVA_END: u64 = 0xfeef_ffff;

            let properties = virtio_iommu_probe_resv_mem {
                head: virtio_iommu_probe_property {
                    type_: VIRTIO_IOMMU_PROBE_T_RESV_MEM.into(),
                    length: (IOMMU_PROBE_SIZE as u16 - PROBE_PROPERTY_SIZE).into(),
                },
                subtype: VIRTIO_IOMMU_RESV_MEM_T_MSI,
                start: X86_MSI_IOVA_START.into(),
                end: X86_MSI_IOVA_END.into(),
                ..Default::default()
            };
            writer
                .write_all(properties.as_slice())
                .map_err(IommuError::GuestMemoryWrite)?;
        }

        // If the device doesn't fill all probe_size bytes with properties,
        // it SHOULD fill the remaining bytes of properties with zeroes.
        let remaining_bytes = writer.available_bytes() - size_of::<virtio_iommu_req_tail>();

        if remaining_bytes > 0 {
            let buffer: Vec<u8> = vec![0; remaining_bytes];
            writer
                .write_all(buffer.as_slice())
                .map_err(IommuError::GuestMemoryWrite)?;
        }

        Ok(properties_size)
    }

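    // Parse the request head from the descriptor chain, dispatch to the
    // appropriate handler, then append the common request tail with the
    // resulting status. Returns the number of bytes written to the writer.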
    fn execute_request(
        &mut self,
        avail_desc: &DescriptorChain,
        endpoints: &Rc<RefCell<BTreeMap<u32, Arc<Mutex<Box<dyn MemoryMapperTrait>>>>>>,
    ) -> Result<usize> {
        let mut reader =
            Reader::new(self.mem.clone(), avail_desc.clone()).map_err(IommuError::CreateReader)?;
        let mut writer =
            Writer::new(self.mem.clone(), avail_desc.clone()).map_err(IommuError::CreateWriter)?;

        // We need at least enough space to write the request tail.
        if writer.available_bytes() < size_of::<virtio_iommu_req_tail>() {
            return Err(IommuError::WriteBufferTooSmall);
        }

        let req_head: virtio_iommu_req_head =
            reader.read_obj().map_err(IommuError::GuestMemoryRead)?;

        let mut tail = virtio_iommu_req_tail {
            status: VIRTIO_IOMMU_S_OK,
            ..Default::default()
        };

        let reply_len = match req_head.type_ {
            VIRTIO_IOMMU_T_ATTACH => {
                self.process_attach_request(&mut reader, &mut tail, endpoints)?
            }
            VIRTIO_IOMMU_T_DETACH => {
                // A few reasons why we don't support VIRTIO_IOMMU_T_DETACH for now:
                //
                // 1. The Linux virtio IOMMU front-end driver doesn't implement the
                //    VIRTIO_IOMMU_T_DETACH request.
                // 2. It doesn't seem possible to dynamically attach and detach an IOMMU
                //    domain if the virtio IOMMU device is running on top of VFIO.
                // 3. Even if VIRTIO_IOMMU_T_DETACH were implemented in the front-end
                //    driver, it could violate the following virtio IOMMU spec: Detach
                //    an endpoint from a domain. When this request completes, the
                //    endpoint cannot access any mapping from that domain anymore.
                //
                //    This is because VFIO doesn't support detaching a single device. When
                //    the virtio-iommu device receives a VIRTIO_IOMMU_T_DETACH request, it
                //    can either:
                //    - detach the group: any other endpoints in the group lose access to
                //      the domain, or
                //    - not detach the group at all: this breaks the spec quoted above.
                tail.status = VIRTIO_IOMMU_S_UNSUPP;
                0
            }
            VIRTIO_IOMMU_T_MAP => self.process_dma_map_request(&mut reader, &mut tail)?,
            VIRTIO_IOMMU_T_UNMAP => self.process_dma_unmap_request(&mut reader, &mut tail)?,
            #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
            VIRTIO_IOMMU_T_PROBE => {
                self.process_probe_request(&mut reader, &mut writer, &mut tail, endpoints)?
            }
            _ => return Err(IommuError::UnexpectedDescriptor),
        };

        writer
            .write_all(tail.as_slice())
            .map_err(IommuError::GuestMemoryWrite)?;
        Ok(reply_len + size_of::<virtio_iommu_req_tail>())
    }

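    // Main loop for the request queue: pop available descriptors, execute the
    // contained request, and return the used descriptor to the guest.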
    async fn request_queue<I: SignalableInterrupt>(
        &mut self,
        mut queue: Queue,
        mut queue_event: EventAsync,
        interrupt: &I,
        endpoints: &Rc<RefCell<BTreeMap<u32, Arc<Mutex<Box<dyn MemoryMapperTrait>>>>>>,
    ) -> Result<()> {
        loop {
            let avail_desc = queue
                .next_async(&self.mem, &mut queue_event)
                .await
                .map_err(IommuError::ReadAsyncDesc)?;
            let desc_index = avail_desc.index;

            let len = match self.execute_request(&avail_desc, endpoints) {
                Ok(len) => len as u32,
                Err(e) => {
                    error!("execute_request failed: {}", e);

                    // If a request type is not recognized, the device SHOULD NOT write
                    // the buffer and SHOULD set the used length to zero.
                    0
                }
            };

            queue.add_used(&self.mem, desc_index, len);
            queue.trigger_interrupt(&self.mem, interrupt);
        }
    }

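    // Register a hot-plugged VFIO endpoint: validate that its PCI address
    // falls within a hot-pluggable range, then wrap its container in a
    // VfioWrapper and add it to the endpoints map.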
    fn handle_add_vfio_device(
        mem: &GuestMemory,
        endpoint_addr: u32,
        container_fd: File,
        endpoints: &Rc<RefCell<BTreeMap<u32, Arc<Mutex<Box<dyn MemoryMapperTrait>>>>>>,
        hp_endpoints_ranges: &Rc<Vec<RangeInclusive<u32>>>,
    ) -> VirtioIOMMUVfioResult {
        if !hp_endpoints_ranges
            .iter()
            .any(|range| range.contains(&endpoint_addr))
        {
            return VirtioIOMMUVfioResult::NotInPCIRanges;
        }

        let vfio_container = match VfioContainer::new_from_container(container_fd) {
            Ok(vfio_container) => vfio_container,
            Err(e) => {
                error!("failed to verify the new container: {}", e);
                return VirtioIOMMUVfioResult::NoAvailableContainer;
            }
        };
        endpoints.borrow_mut().insert(
            endpoint_addr,
            Arc::new(Mutex::new(Box::new(VfioWrapper::new(
                Arc::new(Mutex::new(vfio_container)),
                mem.clone(),
            )))),
        );
        VirtioIOMMUVfioResult::Ok
    }

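    // Unregister a hot-plugged VFIO endpoint by its PCI address.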
    fn handle_del_vfio_device(
        pci_address: u32,
        endpoints: &Rc<RefCell<BTreeMap<u32, Arc<Mutex<Box<dyn MemoryMapperTrait>>>>>>,
    ) -> VirtioIOMMUVfioResult {
        if endpoints.borrow_mut().remove(&pci_address).is_none() {
            error!("there is no VFIO container for endpoint {}", pci_address);
            return VirtioIOMMUVfioResult::NoSuchDevice;
        }
        VirtioIOMMUVfioResult::Ok
    }

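    // Dispatch a VFIO hot-plug command from the host to the add/del handler
    // and wrap the result in a VirtioIOMMUResponse.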
    fn handle_vfio(
        mem: &GuestMemory,
        vfio_cmd: VirtioIOMMUVfioCommand,
        endpoints: &Rc<RefCell<BTreeMap<u32, Arc<Mutex<Box<dyn MemoryMapperTrait>>>>>>,
        hp_endpoints_ranges: &Rc<Vec<RangeInclusive<u32>>>,
    ) -> VirtioIOMMUResponse {
        use VirtioIOMMUVfioCommand::*;
        let vfio_result = match vfio_cmd {
            VfioDeviceAdd {
                endpoint_addr,
                container,
            } => Self::handle_add_vfio_device(
                mem,
                endpoint_addr,
                container,
                endpoints,
                hp_endpoints_ranges,
            ),
            VfioDeviceDel { endpoint_addr } => {
                Self::handle_del_vfio_device(endpoint_addr, endpoints)
            }
        };
        VirtioIOMMUResponse::VfioResponse(vfio_result)
    }

    // Async task that handles messages from the host
    async fn handle_command_tube(
        mem: &GuestMemory,
        command_tube: AsyncTube,
        endpoints: &Rc<RefCell<BTreeMap<u32, Arc<Mutex<Box<dyn MemoryMapperTrait>>>>>>,
        hp_endpoints_ranges: &Rc<Vec<RangeInclusive<u32>>>,
    ) -> Result<()> {
        loop {
            match command_tube.next::<VirtioIOMMURequest>().await {
                Ok(command) => {
                    let response: VirtioIOMMUResponse = match command {
                        VirtioIOMMURequest::VfioCommand(vfio_cmd) => {
                            Self::handle_vfio(mem, vfio_cmd, endpoints, hp_endpoints_ranges)
                        }
                    };
                    if let Err(e) = command_tube.send(response).await {
                        error!("{}", IommuError::VirtioIOMMUResponseError(e));
                    }
                }
                Err(e) => {
                    return Err(IommuError::VirtioIOMMUReqError(e));
                }
            }
        }
    }

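    // Set up the async executor and run the worker's futures (request queue,
    // IRQ resample, kill event, translate requests, and host commands) until
    // one of them exits.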
    fn run(
        &mut self,
        iommu_device_tube: Tube,
        mut queues: Vec<Queue>,
        queue_evts: Vec<Event>,
        kill_evt: Event,
        interrupt: Interrupt,
        endpoints: BTreeMap<u32, Arc<Mutex<Box<dyn MemoryMapperTrait>>>>,
        translate_response_senders: Option<BTreeMap<u32, Tube>>,
        translate_request_rx: Option<Tube>,
    ) -> Result<()> {
        let ex = Executor::new().expect("Failed to create an executor");

        let mut evts_async: Vec<EventAsync> = queue_evts
            .into_iter()
            .map(|e| EventAsync::new(e.0, &ex).expect("Failed to create async event for queue"))
            .collect();
        let interrupt = Rc::new(RefCell::new(interrupt));
        let interrupt_ref = &*interrupt.borrow();

        let (req_queue, req_evt) = (queues.remove(0), evts_async.remove(0));

        let hp_endpoints_ranges = Rc::new(self.hp_endpoints_ranges.clone());
        let mem = Rc::new(self.mem.clone());
        // Contains all pass-through endpoints that attach to this IOMMU device.
        // key: endpoint PCI address
        // value: MemoryMapperTrait for that endpoint
        let endpoints: Rc<RefCell<BTreeMap<u32, Arc<Mutex<Box<dyn MemoryMapperTrait>>>>>> =
            Rc::new(RefCell::new(endpoints));

        let f_resample = async_utils::handle_irq_resample(&ex, interrupt.clone());
        let f_kill = async_utils::await_and_exit(&ex, kill_evt);

        let request_tube = translate_request_rx
            .map(|t| AsyncTube::new(&ex, t).expect("Failed to create async tube for rx"));
        let response_tubes = translate_response_senders.map(|m| {
            m.into_iter()
                .map(|x| {
                    (
                        x.0,
                        AsyncTube::new(&ex, x.1).expect("Failed to create async tube"),
                    )
                })
                .collect()
        });

        let f_handle_translate_request =
            handle_translate_request(&endpoints, request_tube, response_tubes);
        let f_request = self.request_queue(req_queue, req_evt, interrupt_ref, &endpoints);

        let command_tube = AsyncTube::new(&ex, iommu_device_tube).unwrap();
        // Future to handle command messages from the host, such as passing vfio containers.
        let f_cmd = Self::handle_command_tube(&mem, command_tube, &endpoints, &hp_endpoints_ranges);

        let done = async {
            select! {
                res = f_request.fuse() => res.context("error in handling request queue"),
                res = f_resample.fuse() => res.context("error in handle_irq_resample"),
                res = f_kill.fuse() => res.context("error in await_and_exit"),
                res = f_handle_translate_request.fuse() => res.context("error in handle_translate_request"),
                res = f_cmd.fuse() => res.context("error in handling host request"),
            }
        };
        match ex.run_until(done) {
            Ok(Ok(())) => {}
            Ok(Err(e)) => error!("Error in worker: {}", e),
            Err(e) => return Err(IommuError::AsyncExec(e)),
        }

        Ok(())
    }
}

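// Services translate requests from other devices: looks up the endpoint's
// memory mapper and sends the translation result (or None on failure) back
// over the per-endpoint response tube.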
async fn handle_translate_request(
    endpoints: &Rc<RefCell<BTreeMap<u32, Arc<Mutex<Box<dyn MemoryMapperTrait>>>>>>,
    request_tube: Option<AsyncTube>,
    response_tubes: Option<BTreeMap<u32, AsyncTube>>,
) -> Result<()> {
    let request_tube = match request_tube {
        Some(r) => r,
        None => {
            let () = futures::future::pending().await;
            return Ok(());
        }
    };
    let response_tubes = response_tubes.unwrap();
    loop {
        let TranslateRequest {
            endpoint_id,
            iova,
            size,
        } = request_tube.next().await.map_err(IommuError::Tube)?;
        if let Some(mapper) = endpoints.borrow_mut().get(&endpoint_id) {
            response_tubes
                .get(&endpoint_id)
                .unwrap()
                .send(
                    mapper
                        .lock()
                        .translate(iova, size)
                        .map_err(|e| {
                            error!("Failed to handle TranslateRequest: {}", e);
                            e
                        })
                        .ok(),
                )
                .await
                .map_err(IommuError::Tube)?;
        } else {
            error!("endpoint_id {} not found", endpoint_id)
        }
    }
}

/// Virtio device for IOMMU memory management.
pub struct Iommu {
    kill_evt: Option<Event>,
    worker_thread: Option<thread::JoinHandle<Worker>>,
    config: virtio_iommu_config,
    avail_features: u64,
    // Attached endpoints
    // key: endpoint PCI address
    // value: MemoryMapperTrait for that endpoint
    endpoints: BTreeMap<u32, Arc<Mutex<Box<dyn MemoryMapperTrait>>>>,
    // Ranges of hot-pluggable PCI endpoints
    // RangeInclusive: (start endpoint PCI address ..= end endpoint PCI address)
    hp_endpoints_ranges: Vec<RangeInclusive<u32>>,
    translate_response_senders: Option<BTreeMap<u32, Tube>>,
    translate_request_rx: Option<Tube>,
    iommu_device_tube: Option<Tube>,
}

impl Iommu {
    /// Create a new virtio IOMMU device.
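    ///
    /// A minimal construction sketch (hypothetical values; `base_features`
    /// and `iommu_device_tube` are assumed to come from the caller):
    ///
    /// ```ignore
    /// let iommu = Iommu::new(
    ///     base_features,           // virtio feature bits from the bus
    ///     BTreeMap::new(),         // no endpoints bound at boot
    ///     (1u64 << 39) - 1,        // example max guest-physical address
    ///     Vec::new(),              // no hot-pluggable endpoint ranges
    ///     None,                    // translate_response_senders
    ///     None,                    // translate_request_rx
    ///     Some(iommu_device_tube), // control tube; required by activate()
    /// )?;
    /// ```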
    pub fn new(
        base_features: u64,
        endpoints: BTreeMap<u32, Arc<Mutex<Box<dyn MemoryMapperTrait>>>>,
        phys_max_addr: u64,
        hp_endpoints_ranges: Vec<RangeInclusive<u32>>,
        translate_response_senders: Option<BTreeMap<u32, Tube>>,
        translate_request_rx: Option<Tube>,
        iommu_device_tube: Option<Tube>,
    ) -> SysResult<Iommu> {
        let mut page_size_mask = !0_u64;
        for (_, container) in endpoints.iter() {
            page_size_mask &= container
                .lock()
                .get_mask()
                .map_err(|_e| SysError::new(libc::EIO))?;
        }

        if page_size_mask == 0 {
            // If no endpoints are bound to the vIOMMU at boot time,
            // assume the IOVA page size matches the host page size.
            page_size_mask = (pagesize() as u64) - 1;
        }

        let input_range = virtio_iommu_range_64 {
            start: Le64::from(0),
            end: phys_max_addr.into(),
        };

        let config = virtio_iommu_config {
            page_size_mask: page_size_mask.into(),
            input_range,
            #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
            probe_size: (IOMMU_PROBE_SIZE as u32).into(),
            ..Default::default()
        };

        let mut avail_features: u64 = base_features;
        avail_features |= 1 << VIRTIO_IOMMU_F_MAP_UNMAP | 1 << VIRTIO_IOMMU_F_INPUT_RANGE;

        if cfg!(any(target_arch = "x86", target_arch = "x86_64")) {
            avail_features |= 1 << VIRTIO_IOMMU_F_PROBE;
        }

        Ok(Iommu {
            kill_evt: None,
            worker_thread: None,
            config,
            avail_features,
            endpoints,
            hp_endpoints_ranges,
            translate_response_senders,
            translate_request_rx,
            iommu_device_tube,
        })
    }
}

impl Drop for Iommu {
    fn drop(&mut self) {
        if let Some(kill_evt) = self.kill_evt.take() {
            let _ = kill_evt.write(1);
        }

        if let Some(worker_thread) = self.worker_thread.take() {
            let _ = worker_thread.join();
        }
    }
}

impl VirtioDevice for Iommu {
    fn keep_rds(&self) -> Vec<RawDescriptor> {
        let mut rds = Vec::new();

        for (_, mapper) in self.endpoints.iter() {
            rds.append(&mut mapper.lock().as_raw_descriptors());
        }
        if let Some(senders) = &self.translate_response_senders {
            for (_, tube) in senders.iter() {
                rds.push(tube.as_raw_descriptor());
            }
        }
        if let Some(rx) = &self.translate_request_rx {
            rds.push(rx.as_raw_descriptor());
        }

        if let Some(iommu_device_tube) = &self.iommu_device_tube {
            rds.push(iommu_device_tube.as_raw_descriptor());
        }

        rds
    }

    fn device_type(&self) -> u32 {
        TYPE_IOMMU
    }

    fn queue_max_sizes(&self) -> &[u16] {
        QUEUE_SIZES
    }

    fn features(&self) -> u64 {
        self.avail_features
    }

    fn read_config(&self, offset: u64, data: &mut [u8]) {
        let mut config: Vec<u8> = Vec::new();
        config.extend_from_slice(self.config.as_slice());
        copy_config(data, 0, config.as_slice(), offset);
    }

    fn activate(
        &mut self,
        mem: GuestMemory,
        interrupt: Interrupt,
        queues: Vec<Queue>,
        queue_evts: Vec<Event>,
    ) {
        if queues.len() != QUEUE_SIZES.len() || queue_evts.len() != QUEUE_SIZES.len() {
            return;
        }

        let (self_kill_evt, kill_evt) = match Event::new().and_then(|e| Ok((e.try_clone()?, e))) {
            Ok(v) => v,
            Err(e) => {
                error!("failed to create kill Event pair: {}", e);
                return;
            }
        };
        self.kill_evt = Some(self_kill_evt);

        // The least significant set bit of page_size_mask defines the page
        // granularity of IOMMU mappings.
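        // e.g. with a 4 KiB granule the lowest set bit of page_size_mask is
        // bit 12, so page_mask == 0xfff; map requests whose addresses have
        // any of these low bits set are rejected as unaligned.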
        let page_mask = (1u64 << u64::from(self.config.page_size_mask).trailing_zeros()) - 1;
        let eps = self.endpoints.clone();
        let hp_endpoints_ranges = self.hp_endpoints_ranges.to_owned();

        let translate_response_senders = self.translate_response_senders.take();
        let translate_request_rx = self.translate_request_rx.take();

        match self.iommu_device_tube.take() {
            Some(iommu_device_tube) => {
                let worker_result = thread::Builder::new()
                    .name("virtio_iommu".to_string())
                    .spawn(move || {
                        let mut worker = Worker {
                            mem,
                            page_mask,
                            hp_endpoints_ranges,
                            endpoint_map: BTreeMap::new(),
                            domain_map: BTreeMap::new(),
                        };
                        let result = worker.run(
                            iommu_device_tube,
                            queues,
                            queue_evts,
                            kill_evt,
                            interrupt,
                            eps,
                            translate_response_senders,
                            translate_request_rx,
                        );
                        if let Err(e) = result {
                            error!("virtio-iommu worker thread exited with error: {}", e);
                        }
                        worker
                    });

                match worker_result {
                    Err(e) => error!("failed to spawn virtio_iommu worker thread: {}", e),
                    Ok(join_handle) => self.worker_thread = Some(join_handle),
                }
            }
            None => {
                error!("failed to start virtio-iommu worker: No control tube");
                return;
            }
        }
    }

    #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
    fn generate_acpi(
        &mut self,
        pci_address: &Option<PciAddress>,
        mut sdts: Vec<SDT>,
    ) -> Option<Vec<SDT>> {
        const OEM_REVISION: u32 = 1;
        const VIOT_REVISION: u8 = 0;

        for sdt in sdts.iter() {
            // There should only be one VIOT table.
            if sdt.is_signature(b"VIOT") {
                warn!("vIOMMU: duplicate VIOT table detected");
                return None;
            }
        }

        let mut viot = SDT::new(
            *b"VIOT",
            acpi_tables::HEADER_LEN,
            VIOT_REVISION,
            *b"CROSVM",
            *b"CROSVMDT",
            OEM_REVISION,
        );
        viot.append(VirtioIommuViotHeader {
            // # of PCI range nodes + 1 virtio-pci node
            node_count: (self.endpoints.len() + self.hp_endpoints_ranges.len() + 1) as u16,
            node_offset: (viot.len() + std::mem::size_of::<VirtioIommuViotHeader>()) as u16,
            ..Default::default()
        });

        let bdf = pci_address
            .or_else(|| {
                error!("vIOMMU device has no PCI address");
                None
            })?
            .to_u32() as u16;
        let iommu_offset = viot.len();

        viot.append(VirtioIommuViotVirtioPciNode {
            type_: VIRTIO_IOMMU_VIOT_NODE_VIRTIO_IOMMU_PCI,
            length: size_of::<VirtioIommuViotVirtioPciNode>() as u16,
            bdf,
            ..Default::default()
        });

        for (endpoint, _) in self.endpoints.iter() {
            viot.append(VirtioIommuViotPciRangeNode {
                type_: VIRTIO_IOMMU_VIOT_NODE_PCI_RANGE,
                length: size_of::<VirtioIommuViotPciRangeNode>() as u16,
                endpoint_start: *endpoint,
                bdf_start: *endpoint as u16,
                bdf_end: *endpoint as u16,
                output_node: iommu_offset as u16,
                ..Default::default()
            });
        }

        for endpoints_range in self.hp_endpoints_ranges.iter() {
            let (endpoint_start, endpoint_end) = endpoints_range.clone().into_inner();
            viot.append(VirtioIommuViotPciRangeNode {
                type_: VIRTIO_IOMMU_VIOT_NODE_PCI_RANGE,
                length: size_of::<VirtioIommuViotPciRangeNode>() as u16,
                endpoint_start,
                bdf_start: endpoint_start as u16,
                bdf_end: endpoint_end as u16,
                output_node: iommu_offset as u16,
                ..Default::default()
            });
        }

        sdts.push(viot);
        Some(sdts)
    }
}