1 // Copyright 2021 The ChromiumOS Authors
2 // Use of this source code is governed by a BSD-style license that can be
3 // found in the LICENSE file.
4 
5 pub mod ipc_memory_mapper;
6 pub mod memory_mapper;
7 pub mod memory_util;
8 pub mod protocol;
9 pub(crate) mod sys;
10 
11 use std::cell::RefCell;
12 use std::collections::btree_map::Entry;
13 use std::collections::BTreeMap;
14 use std::io;
15 use std::io::Write;
16 use std::mem::size_of;
17 use std::ops::RangeInclusive;
18 use std::rc::Rc;
19 use std::result;
20 use std::sync::Arc;
21 
22 #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
23 use acpi_tables::sdt::SDT;
24 use anyhow::anyhow;
25 use anyhow::Context;
26 use base::debug;
27 use base::error;
28 use base::pagesize;
29 #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
30 use base::warn;
31 use base::AsRawDescriptor;
32 use base::Error as SysError;
33 use base::Event;
34 use base::MappedRegion;
35 use base::MemoryMapping;
36 use base::Protection;
37 use base::RawDescriptor;
38 use base::Result as SysResult;
39 use base::Tube;
40 use base::TubeError;
41 use base::WorkerThread;
42 use cros_async::AsyncError;
43 use cros_async::AsyncTube;
44 use cros_async::EventAsync;
45 use cros_async::Executor;
46 use data_model::Le64;
47 use futures::select;
48 use futures::FutureExt;
49 use hypervisor::MemSlot;
50 use remain::sorted;
51 use sync::Mutex;
52 use thiserror::Error;
53 use vm_memory::GuestAddress;
54 use vm_memory::GuestMemory;
55 use vm_memory::GuestMemoryError;
56 use zerocopy::AsBytes;
57 use zerocopy::FromBytes;
58 
59 #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
60 use crate::pci::PciAddress;
61 use crate::virtio::async_utils;
62 use crate::virtio::copy_config;
63 use crate::virtio::iommu::ipc_memory_mapper::*;
64 use crate::virtio::iommu::memory_mapper::*;
65 use crate::virtio::iommu::protocol::*;
66 use crate::virtio::DescriptorChain;
67 use crate::virtio::DescriptorError;
68 use crate::virtio::DeviceType;
69 use crate::virtio::Interrupt;
70 use crate::virtio::Queue;
71 use crate::virtio::Reader;
72 use crate::virtio::SignalableInterrupt;
73 use crate::virtio::VirtioDevice;
74 use crate::virtio::Writer;
75 use crate::Suspendable;
76 
77 const QUEUE_SIZE: u16 = 256;
78 const NUM_QUEUES: usize = 2;
79 const QUEUE_SIZES: &[u16] = &[QUEUE_SIZE; NUM_QUEUES];
80 
81 // Size of struct virtio_iommu_probe_resv_mem, the property written in response to a PROBE request
82 #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
83 const IOMMU_PROBE_SIZE: usize = size_of::<virtio_iommu_probe_resv_mem>();
84 
85 #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
86 const VIRTIO_IOMMU_VIOT_NODE_PCI_RANGE: u8 = 1;
87 #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
88 const VIRTIO_IOMMU_VIOT_NODE_VIRTIO_IOMMU_PCI: u8 = 3;
89 
90 #[derive(Copy, Clone, Debug, Default, FromBytes, AsBytes)]
91 #[repr(C, packed)]
92 struct VirtioIommuViotHeader {
93     node_count: u16,
94     node_offset: u16,
95     reserved: [u8; 8],
96 }
97 
98 #[derive(Copy, Clone, Debug, Default, FromBytes, AsBytes)]
99 #[repr(C, packed)]
100 struct VirtioIommuViotVirtioPciNode {
101     type_: u8,
102     reserved: [u8; 1],
103     length: u16,
104     segment: u16,
105     bdf: u16,
106     reserved2: [u8; 8],
107 }
108 
109 #[derive(Copy, Clone, Debug, Default, FromBytes, AsBytes)]
110 #[repr(C, packed)]
111 struct VirtioIommuViotPciRangeNode {
112     type_: u8,
113     reserved: [u8; 1],
114     length: u16,
115     endpoint_start: u32,
116     segment_start: u16,
117     segment_end: u16,
118     bdf_start: u16,
119     bdf_end: u16,
120     output_node: u16,
121     reserved2: [u8; 2],
122     reserved3: [u8; 4],
123 }
124 
125 type Result<T> = result::Result<T, IommuError>;
126 
127 #[sorted]
128 #[derive(Error, Debug)]
129 pub enum IommuError {
130     #[error("async executor error: {0}")]
131     AsyncExec(AsyncError),
132     #[error("failed to create reader: {0}")]
133     CreateReader(DescriptorError),
134     #[error("failed to create wait context: {0}")]
135     CreateWaitContext(SysError),
136     #[error("failed to create writer: {0}")]
137     CreateWriter(DescriptorError),
138     #[error("failed getting host address: {0}")]
139     GetHostAddress(GuestMemoryError),
140     #[error("failed to read from guest address: {0}")]
141     GuestMemoryRead(io::Error),
142     #[error("failed to write to guest address: {0}")]
143     GuestMemoryWrite(io::Error),
144     #[error("memory mapper failed: {0}")]
145     MemoryMapper(anyhow::Error),
146     #[error("failed to read descriptor asynchronously: {0}")]
147     ReadAsyncDesc(AsyncError),
148     #[error("failed to read from virtio queue Event: {0}")]
149     ReadQueueEvent(SysError),
150     #[error("tube error: {0}")]
151     Tube(TubeError),
152     #[error("unexpected descriptor error")]
153     UnexpectedDescriptor,
154     #[error("failed to receive virtio-iommu control request: {0}")]
155     VirtioIOMMUReqError(TubeError),
156     #[error("failed to send virtio-iommu control response: {0}")]
157     VirtioIOMMUResponseError(TubeError),
158     #[error("failed to wait for events: {0}")]
159     WaitError(SysError),
160     #[error("write buffer length too small")]
161     WriteBufferTooSmall,
162 }
163 
164 // key: domain ID
165 // value: reference counter and MemoryMapperTrait
166 type DomainMap = BTreeMap<u32, (u32, Arc<Mutex<Box<dyn MemoryMapperTrait>>>)>;
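// Illustrative sketch (not part of the original source): a DomainMap entry for a
// domain shared by two endpoints holds a reference count of 2 and the shared
// mapper, e.g.
//
//     let mut domains: DomainMap = BTreeMap::new();
//     domains.insert(1, (2, mapper.clone()));
//     assert_eq!(domains.get(&1).map(|(refs, _)| *refs), Some(2));
//
// where `mapper` is assumed to be an existing Arc<Mutex<Box<dyn MemoryMapperTrait>>>.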
167 
168 struct DmabufRegionEntry {
169     mmap: MemoryMapping,
170     mem_slot: MemSlot,
171     len: u64,
172 }
173 
174 // Shared state for the virtio-iommu device.
175 struct State {
176     mem: GuestMemory,
177     page_mask: u64,
178     // Hot-pluggable PCI endpoint ranges
179     // RangeInclusive: (start endpoint PCI address .. =end endpoint PCI address)
180     #[cfg_attr(windows, allow(dead_code))]
181     hp_endpoints_ranges: Vec<RangeInclusive<u32>>,
182     // All PCI endpoints that are attached to some IOMMU domain
183     // key: endpoint PCI address
184     // value: attached domain ID
185     endpoint_map: BTreeMap<u32, u32>,
186     // All attached domains
187     domain_map: DomainMap,
188     // Contains all pass-through endpoints that attach to this IOMMU device
189     // key: endpoint PCI address
190     // value: reference counter and MemoryMapperTrait
191     endpoints: BTreeMap<u32, Arc<Mutex<Box<dyn MemoryMapperTrait>>>>,
192     // Contains dmabuf regions
193     // key: guest physical address
194     dmabuf_mem: BTreeMap<u64, DmabufRegionEntry>,
195 }
196 
197 impl State {
198     // Detach the given endpoint if possible, and return whether or not the endpoint
199     // was actually detached. If a successfully detached endpoint has exported
200     // memory, returns an event that will be signaled once all exported memory is released.
201     //
202     // The device MUST ensure that after being detached from a domain, the endpoint
203     // cannot access any mapping from that domain.
204     //
205     // Currently, we only support detaching an endpoint if it is the only endpoint attached
206     // to its domain.
207     fn detach_endpoint(
208         endpoint_map: &mut BTreeMap<u32, u32>,
209         domain_map: &mut DomainMap,
210         endpoint: u32,
211     ) -> (bool, Option<EventAsync>) {
212         let mut evt = None;
213         // The endpoint has attached to an IOMMU domain
214         if let Some(attached_domain) = endpoint_map.get(&endpoint) {
215             // Remove the entry or update the domain reference count
216             if let Entry::Occupied(o) = domain_map.entry(*attached_domain) {
217                 let (refs, mapper) = o.get();
218                 if !mapper.lock().supports_detach() {
219                     return (false, None);
220                 }
221 
222                 match refs {
223                     0 => unreachable!(),
224                     1 => {
225                         evt = mapper.lock().reset_domain();
226                         o.remove();
227                     }
228                     _ => return (false, None),
229                 }
230             }
231         }
232 
233         endpoint_map.remove(&endpoint);
234         (true, evt)
235     }
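    // Illustrative sketch (not part of the original source): with a single
    // endpoint (say 0x8) attached to domain 1, and a mapper that supports
    // detach, detach_endpoint removes both the endpoint entry and the domain
    // entry:
    //
    //     endpoint_map.insert(0x8, 1);
    //     domain_map.insert(1, (1, mapper.clone()));
    //     let (detached, _evt) =
    //         State::detach_endpoint(&mut endpoint_map, &mut domain_map, 0x8);
    //     assert!(detached && endpoint_map.is_empty() && domain_map.is_empty());
    //
    // If the domain's reference count were greater than 1, the call would
    // return (false, None) and leave both maps unchanged.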
236 
237     // Processes an attach request. This may require detaching the endpoint from
238     // its current domain before attaching it to the new domain. If that happens
239     // while the endpoint has exported memory, this function returns an event that
240     // will be signaled once all exported memory is released.
241     //
242     // Note: if a VFIO group contains multiple devices, it could violate the following
243     // requirement from the virtio IOMMU spec: If the VIRTIO_IOMMU_F_BYPASS feature
244     // is negotiated, all accesses from unattached endpoints are allowed and translated
245     // by the IOMMU using the identity function. If the feature is not negotiated, any
246     // memory access from an unattached endpoint fails.
247     //
248     // This happens because, after the virtio-iommu device receives a
249     // VIRTIO_IOMMU_T_ATTACH request for the first endpoint in a VFIO group, any
250     // not-yet-attached endpoints in that group will be able to access the domain.
251     //
252     // This violation is benign for current virtualization use cases. Since device
253     // topology in the guest matches topology in the host, the guest doesn't expect
254     // devices in the same VFIO group to be isolated from each other in the first place.
255     fn process_attach_request(
256         &mut self,
257         reader: &mut Reader,
258         tail: &mut virtio_iommu_req_tail,
259     ) -> Result<(usize, Option<EventAsync>)> {
260         let req: virtio_iommu_req_attach =
261             reader.read_obj().map_err(IommuError::GuestMemoryRead)?;
262         let mut fault_resolved_event = None;
263 
264         // If the reserved field of an ATTACH request is not zero,
265         // the device MUST reject the request and set status to
266         // VIRTIO_IOMMU_S_INVAL.
267         if req.reserved.iter().any(|&x| x != 0) {
268             tail.status = VIRTIO_IOMMU_S_INVAL;
269             return Ok((0, None));
270         }
271 
272         let domain: u32 = req.domain.into();
273         let endpoint: u32 = req.endpoint.into();
274 
275         if let Some(mapper) = self.endpoints.get(&endpoint) {
276             // The same mapper can't be used for two domains at the same time,
277             // since that would result in conflicts/permission leaks between
278             // the two domains.
279             let mapper_id = {
280                 let m = mapper.lock();
281                 ((**m).type_id(), m.id())
282             };
283             for (other_endpoint, other_mapper) in self.endpoints.iter() {
284                 if *other_endpoint == endpoint {
285                     continue;
286                 }
287                 let other_id = {
288                     let m = other_mapper.lock();
289                     ((**m).type_id(), m.id())
290                 };
291                 if mapper_id == other_id {
292                     if !self
293                         .endpoint_map
294                         .get(other_endpoint)
295                         .map_or(true, |d| d == &domain)
296                     {
297                         tail.status = VIRTIO_IOMMU_S_UNSUPP;
298                         return Ok((0, None));
299                     }
300                 }
301             }
302 
303             // If the endpoint identified by `endpoint` is already attached
304             // to another domain, then the device SHOULD first detach it
305             // from that domain and attach it to the one identified by domain.
306             if self.endpoint_map.contains_key(&endpoint) {
307                 // In that case the device SHOULD behave as if the driver issued
308                 // a DETACH request with this endpoint, followed by the ATTACH
309                 // request. If the device cannot do so, it MUST reject the request
310                 // and set status to VIRTIO_IOMMU_S_UNSUPP.
311                 let (detached, evt) =
312                     Self::detach_endpoint(&mut self.endpoint_map, &mut self.domain_map, endpoint);
313                 if !detached {
314                     tail.status = VIRTIO_IOMMU_S_UNSUPP;
315                     return Ok((0, None));
316                 }
317                 fault_resolved_event = evt;
318             }
319 
320             let new_ref = match self.domain_map.get(&domain) {
321                 None => 1,
322                 Some(val) => val.0 + 1,
323             };
324 
325             self.endpoint_map.insert(endpoint, domain);
326             self.domain_map.insert(domain, (new_ref, mapper.clone()));
327         } else {
328             // If the endpoint identified by endpoint doesn’t exist,
329             // the device MUST reject the request and set status to
330             // VIRTIO_IOMMU_S_NOENT.
331             tail.status = VIRTIO_IOMMU_S_NOENT;
332         }
333 
334         Ok((0, fault_resolved_event))
335     }
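    // Illustrative sketch (not part of the original source): two ATTACH
    // requests for endpoints `ep_a` and `ep_b` (backed by distinct mappers)
    // targeting the same domain 5 leave the shared state as:
    //
    //     assert_eq!(state.endpoint_map.get(&ep_a), Some(&5));
    //     assert_eq!(state.endpoint_map.get(&ep_b), Some(&5));
    //     assert_eq!(state.domain_map.get(&5).map(|(refs, _)| *refs), Some(2));
    //
    // where `ep_a` and `ep_b` are assumed to already be present in
    // `state.endpoints`.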
336 
337     fn process_detach_request(
338         &mut self,
339         reader: &mut Reader,
340         tail: &mut virtio_iommu_req_tail,
341     ) -> Result<(usize, Option<EventAsync>)> {
342         let req: virtio_iommu_req_detach =
343             reader.read_obj().map_err(IommuError::GuestMemoryRead)?;
344 
345         // If the endpoint identified by |req.endpoint| doesn’t exist,
346         // the device MUST reject the request and set status to
347         // VIRTIO_IOMMU_S_NOENT.
348         let endpoint: u32 = req.endpoint.into();
349         if !self.endpoints.contains_key(&endpoint) {
350             tail.status = VIRTIO_IOMMU_S_NOENT;
351             return Ok((0, None));
352         }
353 
354         let (detached, evt) =
355             Self::detach_endpoint(&mut self.endpoint_map, &mut self.domain_map, endpoint);
356         if !detached {
357             tail.status = VIRTIO_IOMMU_S_UNSUPP;
358         }
359         Ok((0, evt))
360     }
361 
362     fn process_dma_map_request(
363         &mut self,
364         reader: &mut Reader,
365         tail: &mut virtio_iommu_req_tail,
366     ) -> Result<usize> {
367         let req: virtio_iommu_req_map = reader.read_obj().map_err(IommuError::GuestMemoryRead)?;
368 
369         // If virt_start, phys_start or (virt_end + 1) is not aligned
370         // on the page granularity, the device SHOULD reject the
371         // request and set status to VIRTIO_IOMMU_S_RANGE
372         if self.page_mask & u64::from(req.phys_start) != 0
373             || self.page_mask & u64::from(req.virt_start) != 0
374             || self.page_mask & (u64::from(req.virt_end) + 1) != 0
375         {
376             tail.status = VIRTIO_IOMMU_S_RANGE;
377             return Ok(0);
378         }
379 
380         // If the device doesn’t recognize a flags bit, it MUST reject
381         // the request and set status to VIRTIO_IOMMU_S_INVAL.
382         if u32::from(req.flags) & !VIRTIO_IOMMU_MAP_F_MASK != 0 {
383             tail.status = VIRTIO_IOMMU_S_INVAL;
384             return Ok(0);
385         }
386 
387         let domain: u32 = req.domain.into();
388         if !self.domain_map.contains_key(&domain) {
389             // If domain does not exist, the device SHOULD reject
390             // the request and set status to VIRTIO_IOMMU_S_NOENT.
391             tail.status = VIRTIO_IOMMU_S_NOENT;
392             return Ok(0);
393         }
394 
395         // The device MUST NOT allow writes to a range mapped
396         // without the VIRTIO_IOMMU_MAP_F_WRITE flag.
397         let write_en = u32::from(req.flags) & VIRTIO_IOMMU_MAP_F_WRITE != 0;
398 
399         if let Some(mapper) = self.domain_map.get(&domain) {
400             let size = u64::from(req.virt_end) - u64::from(req.virt_start) + 1u64;
401 
402             let dmabuf_map = self
403                 .dmabuf_mem
404                 .range(..=u64::from(req.phys_start))
405                 .next_back()
406                 .and_then(|(addr, region)| {
407                     if u64::from(req.phys_start) + size <= addr + region.len {
408                         Some(region.mmap.as_ptr() as u64 + (u64::from(req.phys_start) - addr))
409                     } else {
410                         None
411                     }
412                 });
413 
414             let prot = match write_en {
415                 true => Protection::read_write(),
416                 false => Protection::read(),
417             };
418 
419             let vfio_map_result = match dmabuf_map {
420                 // Safe because [dmabuf_map, dmabuf_map + size) refers to an external mmap'ed region.
421                 Some(dmabuf_map) => unsafe {
422                     mapper.1.lock().vfio_dma_map(
423                         req.virt_start.into(),
424                         dmabuf_map as u64,
425                         size,
426                         prot,
427                     )
428                 },
429                 None => mapper.1.lock().add_map(MappingInfo {
430                     iova: req.virt_start.into(),
431                     gpa: GuestAddress(req.phys_start.into()),
432                     size,
433                     prot,
434                 }),
435             };
436 
437             match vfio_map_result {
438                 Ok(AddMapResult::Ok) => (),
439                 Ok(AddMapResult::OverlapFailure) => {
440                     // If a mapping already exists in the requested range,
441                     // the device SHOULD reject the request and set status
442                     // to VIRTIO_IOMMU_S_INVAL.
443                     tail.status = VIRTIO_IOMMU_S_INVAL;
444                 }
445                 Err(e) => return Err(IommuError::MemoryMapper(e)),
446             }
447         }
448 
449         Ok(0)
450     }
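    // Illustrative sketch (not part of the original source): with 4 KiB pages,
    // `page_mask` is 0xfff, so a MAP request with phys_start = 0x3000,
    // virt_start = 0x1000 and virt_end = 0x1fff passes the alignment check
    // (virt_end + 1 = 0x2000 has no offset bits set), while virt_end = 0x1ffe
    // would be rejected with VIRTIO_IOMMU_S_RANGE.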
451 
452     fn process_dma_unmap_request(
453         &mut self,
454         reader: &mut Reader,
455         tail: &mut virtio_iommu_req_tail,
456     ) -> Result<(usize, Option<EventAsync>)> {
457         let req: virtio_iommu_req_unmap = reader.read_obj().map_err(IommuError::GuestMemoryRead)?;
458 
459         let domain: u32 = req.domain.into();
460         let fault_resolved_event = if let Some(mapper) = self.domain_map.get(&domain) {
461             let size = u64::from(req.virt_end) - u64::from(req.virt_start) + 1;
462             let res = mapper
463                 .1
464                 .lock()
465                 .remove_map(u64::from(req.virt_start), size)
466                 .map_err(IommuError::MemoryMapper)?;
467             match res {
468                 RemoveMapResult::Success(evt) => evt,
469                 RemoveMapResult::OverlapFailure => {
470                     // If a mapping affected by the range is not covered in its entirety by the
471                     // range (the UNMAP request would split the mapping), then the device SHOULD
472                     // set the request `status` to VIRTIO_IOMMU_S_RANGE, and SHOULD NOT remove
473                     // any mapping.
474                     tail.status = VIRTIO_IOMMU_S_RANGE;
475                     None
476                 }
477             }
478         } else {
479             // If domain does not exist, the device SHOULD set the
480             // request status to VIRTIO_IOMMU_S_NOENT
481             tail.status = VIRTIO_IOMMU_S_NOENT;
482             None
483         };
484 
485         Ok((0, fault_resolved_event))
486     }
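    // Illustrative sketch (not part of the original source): if domain 5 holds
    // a single mapping covering IOVA [0x1000, 0x2fff], an UNMAP of
    // [0x1000, 0x1fff] would split that mapping, so the mapper reports
    // RemoveMapResult::OverlapFailure and the request completes with
    // VIRTIO_IOMMU_S_RANGE without removing anything; an UNMAP of
    // [0x1000, 0x2fff] removes the mapping and returns VIRTIO_IOMMU_S_OK.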
487 
488     #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
489     fn process_probe_request(
490         &mut self,
491         reader: &mut Reader,
492         writer: &mut Writer,
493         tail: &mut virtio_iommu_req_tail,
494     ) -> Result<usize> {
495         let req: virtio_iommu_req_probe = reader.read_obj().map_err(IommuError::GuestMemoryRead)?;
496         let endpoint: u32 = req.endpoint.into();
497 
498         // If the endpoint identified by endpoint doesn’t exist,
499         // then the device SHOULD reject the request and set status
500         // to VIRTIO_IOMMU_S_NOENT.
501         if !self.endpoints.contains_key(&endpoint) {
502             tail.status = VIRTIO_IOMMU_S_NOENT;
503         }
504 
505         let properties_size = writer.available_bytes() - size_of::<virtio_iommu_req_tail>();
506 
507         // It's OK if properties_size is larger than probe_size
508         // We are good even if properties_size is 0
509         if properties_size < IOMMU_PROBE_SIZE {
510             // If the properties list is smaller than probe_size, the device
511             // SHOULD NOT write any property. It SHOULD reject the request
512             // and set status to VIRTIO_IOMMU_S_INVAL.
513             tail.status = VIRTIO_IOMMU_S_INVAL;
514         } else if tail.status == VIRTIO_IOMMU_S_OK {
515             const VIRTIO_IOMMU_PROBE_T_RESV_MEM: u16 = 1;
516             const VIRTIO_IOMMU_RESV_MEM_T_MSI: u8 = 1;
517             const PROBE_PROPERTY_SIZE: u16 = 4;
518             const X86_MSI_IOVA_START: u64 = 0xfee0_0000;
519             const X86_MSI_IOVA_END: u64 = 0xfeef_ffff;
520 
521             let properties = virtio_iommu_probe_resv_mem {
522                 head: virtio_iommu_probe_property {
523                     type_: VIRTIO_IOMMU_PROBE_T_RESV_MEM.into(),
524                     length: (IOMMU_PROBE_SIZE as u16 - PROBE_PROPERTY_SIZE).into(),
525                 },
526                 subtype: VIRTIO_IOMMU_RESV_MEM_T_MSI,
527                 start: X86_MSI_IOVA_START.into(),
528                 end: X86_MSI_IOVA_END.into(),
529                 ..Default::default()
530             };
531             writer
532                 .write_all(properties.as_bytes())
533                 .map_err(IommuError::GuestMemoryWrite)?;
534         }
535 
536         // If the device doesn’t fill all probe_size bytes with properties,
537         // it SHOULD fill the remaining bytes of properties with zeroes.
538         let remaining_bytes = writer.available_bytes() - size_of::<virtio_iommu_req_tail>();
539 
540         if remaining_bytes > 0 {
541             let buffer: Vec<u8> = vec![0; remaining_bytes];
542             writer
543                 .write_all(buffer.as_slice())
544                 .map_err(IommuError::GuestMemoryWrite)?;
545         }
546 
547         Ok(properties_size)
548     }
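    // Illustrative sketch (not part of the original source): on x86 the PROBE
    // reply carries a single VIRTIO_IOMMU_RESV_MEM property flagging the MSI
    // window [0xfee0_0000, 0xfeef_ffff] as reserved. Its `length` field counts
    // only the bytes after the 4-byte property header, i.e.
    // IOMMU_PROBE_SIZE - PROBE_PROPERTY_SIZE, and any probe_size bytes not
    // covered by the property are zero-filled.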
549 
550     fn execute_request(
551         &mut self,
552         avail_desc: &DescriptorChain,
553     ) -> Result<(usize, Option<EventAsync>)> {
554         let mut reader =
555             Reader::new(self.mem.clone(), avail_desc.clone()).map_err(IommuError::CreateReader)?;
556         let mut writer =
557             Writer::new(self.mem.clone(), avail_desc.clone()).map_err(IommuError::CreateWriter)?;
558 
559         // At a minimum, we need space to write a virtio_iommu_req_tail.
560         if writer.available_bytes() < size_of::<virtio_iommu_req_tail>() {
561             return Err(IommuError::WriteBufferTooSmall);
562         }
563 
564         let req_head: virtio_iommu_req_head =
565             reader.read_obj().map_err(IommuError::GuestMemoryRead)?;
566 
567         let mut tail = virtio_iommu_req_tail {
568             status: VIRTIO_IOMMU_S_OK,
569             ..Default::default()
570         };
571 
572         let (reply_len, fault_resolved_event) = match req_head.type_ {
573             VIRTIO_IOMMU_T_ATTACH => self.process_attach_request(&mut reader, &mut tail)?,
574             VIRTIO_IOMMU_T_DETACH => self.process_detach_request(&mut reader, &mut tail)?,
575             VIRTIO_IOMMU_T_MAP => (self.process_dma_map_request(&mut reader, &mut tail)?, None),
576             VIRTIO_IOMMU_T_UNMAP => self.process_dma_unmap_request(&mut reader, &mut tail)?,
577             #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
578             VIRTIO_IOMMU_T_PROBE => (
579                 self.process_probe_request(&mut reader, &mut writer, &mut tail)?,
580                 None,
581             ),
582             _ => return Err(IommuError::UnexpectedDescriptor),
583         };
584 
585         writer
586             .write_all(tail.as_bytes())
587             .map_err(IommuError::GuestMemoryWrite)?;
588         Ok((
589             (reply_len as usize) + size_of::<virtio_iommu_req_tail>(),
590             fault_resolved_event,
591         ))
592     }
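    // Illustrative sketch (not part of the original source): for a successful
    // VIRTIO_IOMMU_T_MAP request the handler returns 0 payload bytes, so the
    // used length reported back to the driver is just
    // size_of::<virtio_iommu_req_tail>(), with tail.status = VIRTIO_IOMMU_S_OK.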
593 }
594 
595 async fn request_queue<I: SignalableInterrupt>(
596     state: &Rc<RefCell<State>>,
597     mut queue: Queue,
598     mut queue_event: EventAsync,
599     interrupt: I,
600 ) -> Result<()> {
601     loop {
602         let mem = state.borrow().mem.clone();
603         let avail_desc = queue
604             .next_async(&mem, &mut queue_event)
605             .await
606             .map_err(IommuError::ReadAsyncDesc)?;
607         let desc_index = avail_desc.index;
608 
609         let (len, fault_resolved_event) = match state.borrow_mut().execute_request(&avail_desc) {
610             Ok(res) => res,
611             Err(e) => {
612                 error!("execute_request failed: {}", e);
613 
614                 // If a request type is not recognized, the device SHOULD NOT write
615                 // the buffer and SHOULD set the used length to zero
616                 (0, None)
617             }
618         };
619 
620         if let Some(fault_resolved_event) = fault_resolved_event {
621             debug!("waiting for iommu fault resolution");
622             fault_resolved_event
623                 .next_val()
624                 .await
625                 .expect("failed waiting for fault");
626             debug!("iommu fault resolved");
627         }
628 
629         queue.add_used(&mem, desc_index, len as u32);
630         queue.trigger_interrupt(&mem, &interrupt);
631     }
632 }
633 
634 fn run(
635     state: State,
636     iommu_device_tube: Tube,
637     mut queues: Vec<(Queue, Event)>,
638     kill_evt: Event,
639     interrupt: Interrupt,
640     translate_response_senders: Option<BTreeMap<u32, Tube>>,
641     translate_request_rx: Option<Tube>,
642 ) -> Result<()> {
643     let state = Rc::new(RefCell::new(state));
644     let ex = Executor::new().expect("Failed to create an executor");
645 
646     let (req_queue, req_evt) = queues.remove(0);
647     let req_evt = EventAsync::new(req_evt, &ex).expect("Failed to create async event for queue");
648 
649     let f_resample = async_utils::handle_irq_resample(&ex, interrupt.clone());
650     let f_kill = async_utils::await_and_exit(&ex, kill_evt);
651 
652     let request_tube = translate_request_rx
653         .map(|t| AsyncTube::new(&ex, t).expect("Failed to create async tube for rx"));
654     let response_tubes = translate_response_senders.map(|m| {
655         m.into_iter()
656             .map(|x| {
657                 (
658                     x.0,
659                     AsyncTube::new(&ex, x.1).expect("Failed to create async tube"),
660                 )
661             })
662             .collect()
663     });
664 
665     let f_handle_translate_request =
666         sys::handle_translate_request(&ex, &state, request_tube, response_tubes);
667     let f_request = request_queue(&state, req_queue, req_evt, interrupt);
668 
669     let command_tube = AsyncTube::new(&ex, iommu_device_tube).unwrap();
670     // Future to handle command messages from the host, such as passing vfio containers.
671     let f_cmd = sys::handle_command_tube(&state, command_tube);
672 
673     let done = async {
674         select! {
675             res = f_request.fuse() => res.context("error in handling request queue"),
676             res = f_resample.fuse() => res.context("error in handle_irq_resample"),
677             res = f_kill.fuse() => res.context("error in await_and_exit"),
678             res = f_handle_translate_request.fuse() => {
679                 res.context("error in handle_translate_request")
680             }
681             res = f_cmd.fuse() => res.context("error in handling host request"),
682         }
683     };
684     match ex.run_until(done) {
685         Ok(Ok(())) => {}
686         Ok(Err(e)) => error!("Error in worker: {:#}", e),
687         Err(e) => return Err(IommuError::AsyncExec(e)),
688     }
689 
690     Ok(())
691 }
692 
693 /// Virtio device for IOMMU memory management.
694 pub struct Iommu {
695     worker_thread: Option<WorkerThread<()>>,
696     config: virtio_iommu_config,
697     avail_features: u64,
698     // Attached endpoints
699     // key: endpoint PCI address
700     // value: reference counter and MemoryMapperTrait
701     endpoints: BTreeMap<u32, Arc<Mutex<Box<dyn MemoryMapperTrait>>>>,
702     // Hot-pluggable PCI endpoints ranges
703     // RangeInclusive: (start endpoint PCI address .. =end endpoint PCI address)
704     hp_endpoints_ranges: Vec<RangeInclusive<u32>>,
705     translate_response_senders: Option<BTreeMap<u32, Tube>>,
706     translate_request_rx: Option<Tube>,
707     iommu_device_tube: Option<Tube>,
708 }
709 
710 impl Iommu {
711     /// Create a new virtio IOMMU device.
712     pub fn new(
713         base_features: u64,
714         endpoints: BTreeMap<u32, Arc<Mutex<Box<dyn MemoryMapperTrait>>>>,
715         iova_max_addr: u64,
716         hp_endpoints_ranges: Vec<RangeInclusive<u32>>,
717         translate_response_senders: Option<BTreeMap<u32, Tube>>,
718         translate_request_rx: Option<Tube>,
719         iommu_device_tube: Option<Tube>,
720     ) -> SysResult<Iommu> {
721         let mut page_size_mask = !((pagesize() as u64) - 1);
722         for (_, container) in endpoints.iter() {
723             page_size_mask &= container
724                 .lock()
725                 .get_mask()
726                 .map_err(|_e| SysError::new(libc::EIO))?;
727         }
728 
729         if page_size_mask == 0 {
730             return Err(SysError::new(libc::EIO));
731         }
732 
733         let input_range = virtio_iommu_range_64 {
734             start: Le64::from(0),
735             end: iova_max_addr.into(),
736         };
737 
738         let config = virtio_iommu_config {
739             page_size_mask: page_size_mask.into(),
740             input_range,
741             #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
742             probe_size: (IOMMU_PROBE_SIZE as u32).into(),
743             ..Default::default()
744         };
745 
746         let mut avail_features: u64 = base_features;
747         avail_features |= 1 << VIRTIO_IOMMU_F_MAP_UNMAP
748             | 1 << VIRTIO_IOMMU_F_INPUT_RANGE
749             | 1 << VIRTIO_IOMMU_F_MMIO;
750 
751         if cfg!(any(target_arch = "x86", target_arch = "x86_64")) {
752             avail_features |= 1 << VIRTIO_IOMMU_F_PROBE;
753         }
754 
755         Ok(Iommu {
756             worker_thread: None,
757             config,
758             avail_features,
759             endpoints,
760             hp_endpoints_ranges,
761             translate_response_senders,
762             translate_request_rx,
763             iommu_device_tube,
764         })
765     }
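    // Illustrative sketch (not part of the original source): with a 4 KiB host
    // page size the starting mask is !0xfff (every granule of 4 KiB and up).
    // If one endpoint's mapper only supports 2 MiB and larger pages
    // (!0x1f_ffff), the intersection advertised in the config space becomes
    // !0x1f_ffff; if the intersection of all masks were 0, new() would fail
    // with EIO.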
766 }
767 
768 impl VirtioDevice for Iommu {
769     fn keep_rds(&self) -> Vec<RawDescriptor> {
770         let mut rds = Vec::new();
771 
772         for (_, mapper) in self.endpoints.iter() {
773             rds.append(&mut mapper.lock().as_raw_descriptors());
774         }
775         if let Some(senders) = &self.translate_response_senders {
776             for (_, tube) in senders.iter() {
777                 rds.push(tube.as_raw_descriptor());
778             }
779         }
780         if let Some(rx) = &self.translate_request_rx {
781             rds.push(rx.as_raw_descriptor());
782         }
783 
784         if let Some(iommu_device_tube) = &self.iommu_device_tube {
785             rds.push(iommu_device_tube.as_raw_descriptor());
786         }
787 
788         rds
789     }
790 
791     fn device_type(&self) -> DeviceType {
792         DeviceType::Iommu
793     }
794 
795     fn queue_max_sizes(&self) -> &[u16] {
796         QUEUE_SIZES
797     }
798 
799     fn features(&self) -> u64 {
800         self.avail_features
801     }
802 
803     fn read_config(&self, offset: u64, data: &mut [u8]) {
804         let mut config: Vec<u8> = Vec::new();
805         config.extend_from_slice(self.config.as_bytes());
806         copy_config(data, 0, config.as_slice(), offset);
807     }
808 
809     fn activate(
810         &mut self,
811         mem: GuestMemory,
812         interrupt: Interrupt,
813         queues: Vec<(Queue, Event)>,
814     ) -> anyhow::Result<()> {
815         if queues.len() != QUEUE_SIZES.len() {
816             return Err(anyhow!(
817                 "expected {} queues, got {}",
818                 QUEUE_SIZES.len(),
819                 queues.len()
820             ));
821         }
822 
823         // The least significant set bit of page_size_mask defines the page
824         // granularity of IOMMU mappings.
825         let page_mask = (1u64 << u64::from(self.config.page_size_mask).trailing_zeros()) - 1;
826         let eps = self.endpoints.clone();
827         let hp_endpoints_ranges = self.hp_endpoints_ranges.to_owned();
828 
829         let translate_response_senders = self.translate_response_senders.take();
830         let translate_request_rx = self.translate_request_rx.take();
831 
832         let iommu_device_tube = self
833             .iommu_device_tube
834             .take()
835             .context("failed to start virtio-iommu worker: No control tube")?;
836 
837         self.worker_thread = Some(WorkerThread::start("v_iommu", move |kill_evt| {
838             let state = State {
839                 mem,
840                 page_mask,
841                 hp_endpoints_ranges,
842                 endpoint_map: BTreeMap::new(),
843                 domain_map: BTreeMap::new(),
844                 endpoints: eps,
845                 dmabuf_mem: BTreeMap::new(),
846             };
847             let result = run(
848                 state,
849                 iommu_device_tube,
850                 queues,
851                 kill_evt,
852                 interrupt,
853                 translate_response_senders,
854                 translate_request_rx,
855             );
856             if let Err(e) = result {
857                 error!("virtio-iommu worker thread exited with error: {}", e);
858             }
859         }));
860         Ok(())
861     }
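    // Illustrative sketch (not part of the original source): a page_size_mask
    // of 0xffff_ffff_ffff_f000 has trailing_zeros() == 12, so the worker's
    // page_mask is (1 << 12) - 1 = 0xfff, i.e. exactly the in-page offset bits
    // that every MAP boundary must have clear.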
862 
863     #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
864     fn generate_acpi(
865         &mut self,
866         pci_address: &Option<PciAddress>,
867         mut sdts: Vec<SDT>,
868     ) -> Option<Vec<SDT>> {
869         const OEM_REVISION: u32 = 1;
870         const VIOT_REVISION: u8 = 0;
871 
872         for sdt in sdts.iter() {
873             // there should only be one VIOT table
874             if sdt.is_signature(b"VIOT") {
875                 warn!("vIOMMU: duplicate VIOT table detected");
876                 return None;
877             }
878         }
879 
880         let mut viot = SDT::new(
881             *b"VIOT",
882             acpi_tables::HEADER_LEN,
883             VIOT_REVISION,
884             *b"CROSVM",
885             *b"CROSVMDT",
886             OEM_REVISION,
887         );
888         viot.append(VirtioIommuViotHeader {
889             // # of PCI range nodes + 1 virtio-pci node
890             node_count: (self.endpoints.len() + self.hp_endpoints_ranges.len() + 1) as u16,
891             node_offset: (viot.len() + std::mem::size_of::<VirtioIommuViotHeader>()) as u16,
892             ..Default::default()
893         });
894 
895         let bdf = pci_address
896             .or_else(|| {
897                 error!("vIOMMU device has no PCI address");
898                 None
899             })?
900             .to_u32() as u16;
901         let iommu_offset = viot.len();
902 
903         viot.append(VirtioIommuViotVirtioPciNode {
904             type_: VIRTIO_IOMMU_VIOT_NODE_VIRTIO_IOMMU_PCI,
905             length: size_of::<VirtioIommuViotVirtioPciNode>() as u16,
906             bdf,
907             ..Default::default()
908         });
909 
910         for (endpoint, _) in self.endpoints.iter() {
911             viot.append(VirtioIommuViotPciRangeNode {
912                 type_: VIRTIO_IOMMU_VIOT_NODE_PCI_RANGE,
913                 length: size_of::<VirtioIommuViotPciRangeNode>() as u16,
914                 endpoint_start: *endpoint,
915                 bdf_start: *endpoint as u16,
916                 bdf_end: *endpoint as u16,
917                 output_node: iommu_offset as u16,
918                 ..Default::default()
919             });
920         }
921 
922         for endpoints_range in self.hp_endpoints_ranges.iter() {
923             let (endpoint_start, endpoint_end) = endpoints_range.clone().into_inner();
924             viot.append(VirtioIommuViotPciRangeNode {
925                 type_: VIRTIO_IOMMU_VIOT_NODE_PCI_RANGE,
926                 length: size_of::<VirtioIommuViotPciRangeNode>() as u16,
927                 endpoint_start,
928                 bdf_start: endpoint_start as u16,
929                 bdf_end: endpoint_end as u16,
930                 output_node: iommu_offset as u16,
931                 ..Default::default()
932             });
933         }
934 
935         sdts.push(viot);
936         Some(sdts)
937     }
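    // Illustrative sketch (not part of the original source): with two
    // statically attached endpoints and one hot-pluggable range, the VIOT
    // advertises node_count = 2 + 1 + 1 = 4 (one virtio-iommu PCI node plus
    // three PCI range nodes), and each range node's output_node holds the byte
    // offset of that virtio-iommu node within the table.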
938 }
939 
940 impl Suspendable for Iommu {}
941