// Copyright 2021 The ChromiumOS Authors
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.

pub mod ipc_memory_mapper;
pub mod memory_mapper;
pub mod protocol;
pub(crate) mod sys;

use std::cell::RefCell;
use std::collections::btree_map::Entry;
use std::collections::BTreeMap;
use std::io;
use std::io::Write;
use std::mem::size_of;
use std::ops::RangeInclusive;
use std::rc::Rc;
use std::result;
use std::sync::Arc;

#[cfg(target_arch = "x86_64")]
use acpi_tables::sdt::SDT;
use anyhow::anyhow;
use anyhow::Context;
use base::debug;
use base::error;
use base::pagesize;
#[cfg(target_arch = "x86_64")]
use base::warn;
use base::AsRawDescriptor;
use base::Error as SysError;
use base::Event;
use base::MappedRegion;
use base::MemoryMapping;
use base::Protection;
use base::RawDescriptor;
use base::Result as SysResult;
use base::Tube;
use base::TubeError;
use base::WorkerThread;
use cros_async::AsyncError;
use cros_async::AsyncTube;
use cros_async::EventAsync;
use cros_async::Executor;
use data_model::Le64;
use futures::select;
use futures::FutureExt;
use remain::sorted;
use sync::Mutex;
use thiserror::Error;
use vm_control::VmMemoryRegionId;
use vm_memory::GuestAddress;
use vm_memory::GuestMemory;
use vm_memory::GuestMemoryError;
#[cfg(target_arch = "x86_64")]
use zerocopy::FromBytes;
#[cfg(target_arch = "x86_64")]
use zerocopy::Immutable;
use zerocopy::IntoBytes;
#[cfg(target_arch = "x86_64")]
use zerocopy::KnownLayout;

#[cfg(target_arch = "x86_64")]
use crate::pci::PciAddress;
use crate::virtio::async_utils;
use crate::virtio::copy_config;
use crate::virtio::iommu::memory_mapper::*;
use crate::virtio::iommu::protocol::*;
use crate::virtio::DescriptorChain;
use crate::virtio::DeviceType;
use crate::virtio::Interrupt;
use crate::virtio::Queue;
use crate::virtio::Reader;
use crate::virtio::VirtioDevice;
#[cfg(target_arch = "x86_64")]
use crate::virtio::Writer;

const QUEUE_SIZE: u16 = 256;
const NUM_QUEUES: usize = 2;
const QUEUE_SIZES: &[u16] = &[QUEUE_SIZE; NUM_QUEUES];

// Size of struct virtio_iommu_probe_property
#[cfg(target_arch = "x86_64")]
const IOMMU_PROBE_SIZE: usize = size_of::<virtio_iommu_probe_resv_mem>();

#[cfg(target_arch = "x86_64")]
const VIRTIO_IOMMU_VIOT_NODE_PCI_RANGE: u8 = 1;
#[cfg(target_arch = "x86_64")]
const VIRTIO_IOMMU_VIOT_NODE_VIRTIO_IOMMU_PCI: u8 = 3;

#[derive(Copy, Clone, Debug, Default, FromBytes, Immutable, IntoBytes, KnownLayout)]
#[repr(C, packed)]
#[cfg(target_arch = "x86_64")]
struct VirtioIommuViotHeader {
    node_count: u16,
    node_offset: u16,
    reserved: [u8; 8],
}

#[derive(Copy, Clone, Debug, Default, FromBytes, Immutable, IntoBytes, KnownLayout)]
#[repr(C, packed)]
#[cfg(target_arch = "x86_64")]
struct VirtioIommuViotVirtioPciNode {
    type_: u8,
    reserved: [u8; 1],
    length: u16,
    segment: u16,
    bdf: u16,
    reserved2: [u8; 8],
}

#[derive(Copy, Clone, Debug, Default, FromBytes, Immutable, IntoBytes, KnownLayout)]
#[repr(C, packed)]
#[cfg(target_arch = "x86_64")]
struct VirtioIommuViotPciRangeNode {
    type_: u8,
    reserved: [u8; 1],
    length: u16,
    endpoint_start: u32,
    segment_start: u16,
    segment_end: u16,
    bdf_start: u16,
    bdf_end: u16,
    output_node: u16,
    reserved2: [u8; 2],
    reserved3: [u8; 4],
}

type Result<T> = result::Result<T, IommuError>;

#[sorted]
#[derive(Error, Debug)]
pub enum IommuError {
    #[error("async executor error: {0}")]
    AsyncExec(AsyncError),
    #[error("failed to create wait context: {0}")]
    CreateWaitContext(SysError),
    #[error("failed getting host address: {0}")]
    GetHostAddress(GuestMemoryError),
    #[error("failed to read from guest address: {0}")]
    GuestMemoryRead(io::Error),
    #[error("failed to write to guest address: {0}")]
    GuestMemoryWrite(io::Error),
    #[error("memory mapper failed: {0}")]
    MemoryMapper(anyhow::Error),
    #[error("Failed to read descriptor asynchronously: {0}")]
    ReadAsyncDesc(AsyncError),
    #[error("failed to read from virtio queue Event: {0}")]
    ReadQueueEvent(SysError),
    #[error("tube error: {0}")]
    Tube(TubeError),
    #[error("unexpected descriptor error")]
    UnexpectedDescriptor,
    #[error("failed to receive virtio-iommu control request: {0}")]
    VirtioIOMMUReqError(TubeError),
    #[error("failed to send virtio-iommu control response: {0}")]
    VirtioIOMMUResponseError(TubeError),
    #[error("failed to wait for events: {0}")]
    WaitError(SysError),
    #[error("write buffer length too small")]
    WriteBufferTooSmall,
}

// key: domain ID
// value: reference counter and MemoryMapperTrait
type DomainMap = BTreeMap<u32, (u32, Arc<Mutex<Box<dyn MemoryMapperTrait>>>)>;

struct DmabufRegionEntry {
    mmap: MemoryMapping,
    region_id: VmMemoryRegionId,
    size: u64,
}

// Shared state for the virtio-iommu device.
struct State {
    mem: GuestMemory,
    page_mask: u64,
    // Hot-pluggable PCI endpoint ranges
    // RangeInclusive: (start endpoint PCI address ..= end endpoint PCI address)
    #[cfg_attr(windows, allow(dead_code))]
    hp_endpoints_ranges: Vec<RangeInclusive<u32>>,
    // All PCI endpoints that are attached to an IOMMU domain
    // key: endpoint PCI address
    // value: attached domain ID
    endpoint_map: BTreeMap<u32, u32>,
    // All attached domains
    domain_map: DomainMap,
    // Contains all pass-through endpoints that attach to this IOMMU device
    // key: endpoint PCI address
    // value: MemoryMapperTrait
    endpoints: BTreeMap<u32, Arc<Mutex<Box<dyn MemoryMapperTrait>>>>,
    // Contains dmabuf regions
    // key: guest physical address
    dmabuf_mem: BTreeMap<u64, DmabufRegionEntry>,
}

impl State {
    // Detach the given endpoint if possible, and return whether or not the endpoint
    // was actually detached. If a successfully detached endpoint has exported
    // memory, returns an event that will be signaled once all exported memory is released.
    //
    // The device MUST ensure that after being detached from a domain, the endpoint
    // cannot access any mapping from that domain.
    //
    // Currently, we only support detaching an endpoint if it is the only endpoint attached
    // to its domain.
    fn detach_endpoint(
        endpoint_map: &mut BTreeMap<u32, u32>,
        domain_map: &mut DomainMap,
        endpoint: u32,
    ) -> (bool, Option<EventAsync>) {
        let mut evt = None;
        // The endpoint has attached to an IOMMU domain
        if let Some(attached_domain) = endpoint_map.get(&endpoint) {
            // Remove the entry or update the domain reference count
            if let Entry::Occupied(o) = domain_map.entry(*attached_domain) {
                let (refs, mapper) = o.get();
                if !mapper.lock().supports_detach() {
                    return (false, None);
                }

                match refs {
                    0 => unreachable!(),
                    1 => {
                        evt = mapper.lock().reset_domain();
                        o.remove();
                    }
                    _ => return (false, None),
                }
            }
        }

        endpoint_map.remove(&endpoint);
        (true, evt)
    }

    // Processes an attach request. This may require detaching the endpoint from
    // its current domain before attaching it to the new one. If that happens
    // while the endpoint has exported memory, this function returns an event that
    // will be signaled once all exported memory is released.
    //
    // Note: if a VFIO group contains multiple devices, it could violate the following
    // requirement from the virtio IOMMU spec: If the VIRTIO_IOMMU_F_BYPASS feature
    // is negotiated, all accesses from unattached endpoints are allowed and translated
    // by the IOMMU using the identity function. If the feature is not negotiated, any
    // memory access from an unattached endpoint fails.
    //
    // That is, once the virtio-iommu device receives a VIRTIO_IOMMU_T_ATTACH
    // request for the first endpoint in a VFIO group, any not-yet-attached endpoints
    // in the VFIO group will be able to access the domain.
    //
    // This violation is benign for current virtualization use cases. Since device
    // topology in the guest matches topology in the host, the guest doesn't expect
    // devices in the same VFIO group to be isolated from each other in the first place.
    fn process_attach_request(
        &mut self,
        reader: &mut Reader,
        tail: &mut virtio_iommu_req_tail,
    ) -> Result<(usize, Option<EventAsync>)> {
        let req: virtio_iommu_req_attach =
            reader.read_obj().map_err(IommuError::GuestMemoryRead)?;
        let mut fault_resolved_event = None;

        // If the reserved field of an ATTACH request is not zero,
        // the device MUST reject the request and set status to
        // VIRTIO_IOMMU_S_INVAL.
        if req.reserved.iter().any(|&x| x != 0) {
            tail.status = VIRTIO_IOMMU_S_INVAL;
            return Ok((0, None));
        }

        let domain: u32 = req.domain.into();
        let endpoint: u32 = req.endpoint.into();

        if let Some(mapper) = self.endpoints.get(&endpoint) {
            // The same mapper can't be used for two domains at the same time,
            // since that would result in conflicts/permission leaks between
            // the two domains.
            let mapper_id = {
                let m = mapper.lock();
                ((**m).type_id(), m.id())
            };
            for (other_endpoint, other_mapper) in self.endpoints.iter() {
                if *other_endpoint == endpoint {
                    continue;
                }
                let other_id = {
                    let m = other_mapper.lock();
                    ((**m).type_id(), m.id())
                };
                if mapper_id == other_id {
                    if !self
                        .endpoint_map
                        .get(other_endpoint)
                        .map_or(true, |d| d == &domain)
                    {
                        tail.status = VIRTIO_IOMMU_S_UNSUPP;
                        return Ok((0, None));
                    }
                }
            }

            // If the endpoint identified by `endpoint` is already attached
            // to another domain, then the device SHOULD first detach it
            // from that domain and attach it to the one identified by domain.
            if self.endpoint_map.contains_key(&endpoint) {
                // In that case the device SHOULD behave as if the driver issued
                // a DETACH request with this endpoint, followed by the ATTACH
                // request. If the device cannot do so, it MUST reject the request
                // and set status to VIRTIO_IOMMU_S_UNSUPP.
                let (detached, evt) =
                    Self::detach_endpoint(&mut self.endpoint_map, &mut self.domain_map, endpoint);
                if !detached {
                    tail.status = VIRTIO_IOMMU_S_UNSUPP;
                    return Ok((0, None));
                }
                fault_resolved_event = evt;
            }

            let new_ref = match self.domain_map.get(&domain) {
                None => 1,
                Some(val) => val.0 + 1,
            };

            self.endpoint_map.insert(endpoint, domain);
            self.domain_map.insert(domain, (new_ref, mapper.clone()));
        } else {
            // If the endpoint identified by endpoint doesn’t exist,
            // the device MUST reject the request and set status to
            // VIRTIO_IOMMU_S_NOENT.
            tail.status = VIRTIO_IOMMU_S_NOENT;
        }

        Ok((0, fault_resolved_event))
    }

    fn process_detach_request(
        &mut self,
        reader: &mut Reader,
        tail: &mut virtio_iommu_req_tail,
    ) -> Result<(usize, Option<EventAsync>)> {
        let req: virtio_iommu_req_detach =
            reader.read_obj().map_err(IommuError::GuestMemoryRead)?;

        // If the endpoint identified by |req.endpoint| doesn’t exist,
        // the device MUST reject the request and set status to
        // VIRTIO_IOMMU_S_NOENT.
        let endpoint: u32 = req.endpoint.into();
        if !self.endpoints.contains_key(&endpoint) {
            tail.status = VIRTIO_IOMMU_S_NOENT;
            return Ok((0, None));
        }

        let (detached, evt) =
            Self::detach_endpoint(&mut self.endpoint_map, &mut self.domain_map, endpoint);
        if !detached {
            tail.status = VIRTIO_IOMMU_S_UNSUPP;
        }
        Ok((0, evt))
    }

    fn process_dma_map_request(
        &mut self,
        reader: &mut Reader,
        tail: &mut virtio_iommu_req_tail,
    ) -> Result<usize> {
        let req: virtio_iommu_req_map = reader.read_obj().map_err(IommuError::GuestMemoryRead)?;

        let phys_start = u64::from(req.phys_start);
        let virt_start = u64::from(req.virt_start);
        let virt_end = u64::from(req.virt_end);

        // enforce driver requirement: virt_end MUST be strictly greater than virt_start.
        if virt_start >= virt_end {
            tail.status = VIRTIO_IOMMU_S_INVAL;
            return Ok(0);
        }

        // If virt_start, phys_start or (virt_end + 1) is not aligned
        // on the page granularity, the device SHOULD reject the
        // request and set status to VIRTIO_IOMMU_S_RANGE
        if self.page_mask & phys_start != 0
            || self.page_mask & virt_start != 0
            || self.page_mask & (virt_end + 1) != 0
        {
            tail.status = VIRTIO_IOMMU_S_RANGE;
            return Ok(0);
        }
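        // For example, with 4 KiB pages page_mask is 0xfff, so a request with
        // virt_start = 0x1000 and virt_end = 0x1fff is accepted, while
        // virt_end = 0x1800 would be rejected because 0x1801 is not page aligned.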

        // If the device doesn’t recognize a flags bit, it MUST reject
        // the request and set status to VIRTIO_IOMMU_S_INVAL.
        if u32::from(req.flags) & !VIRTIO_IOMMU_MAP_F_MASK != 0 {
            tail.status = VIRTIO_IOMMU_S_INVAL;
            return Ok(0);
        }

        let domain: u32 = req.domain.into();
        if !self.domain_map.contains_key(&domain) {
            // If domain does not exist, the device SHOULD reject
            // the request and set status to VIRTIO_IOMMU_S_NOENT.
            tail.status = VIRTIO_IOMMU_S_NOENT;
            return Ok(0);
        }

        // The device MUST NOT allow writes to a range mapped
        // without the VIRTIO_IOMMU_MAP_F_WRITE flag.
        let write_en = u32::from(req.flags) & VIRTIO_IOMMU_MAP_F_WRITE != 0;

        if let Some(mapper) = self.domain_map.get(&domain) {
            let gpa = phys_start;
            let iova = virt_start;
            let Some(size) = u64::checked_add(virt_end - virt_start, 1) else {
                // This implementation doesn't support the unlikely request for size == u64::MAX + 1.
                tail.status = VIRTIO_IOMMU_S_DEVERR;
                return Ok(0);
            };

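            // Look up the dmabuf region (if any) whose GPA range contains [gpa, gpa + size)
            // and, if found, translate the GPA to the host virtual address of that mapping.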
            let dmabuf_map =
                self.dmabuf_mem
                    .range(..=gpa)
                    .next_back()
                    .and_then(|(base_gpa, region)| {
                        if gpa + size <= base_gpa + region.size {
                            let offset = gpa - base_gpa;
                            Some(region.mmap.as_ptr() as u64 + offset)
                        } else {
                            None
                        }
                    });

            let prot = match write_en {
                true => Protection::read_write(),
                false => Protection::read(),
            };

            let vfio_map_result = match dmabuf_map {
                // SAFETY:
                // Safe because [dmabuf_map, dmabuf_map + size) refers to an external mmap'ed
                // region.
                Some(dmabuf_map) => unsafe {
                    mapper.1.lock().vfio_dma_map(iova, dmabuf_map, size, prot)
                },
                None => mapper.1.lock().add_map(MappingInfo {
                    iova,
                    gpa: GuestAddress(gpa),
                    size,
                    prot,
                }),
            };

            match vfio_map_result {
                Ok(AddMapResult::Ok) => (),
                Ok(AddMapResult::OverlapFailure) => {
                    // If a mapping already exists in the requested range,
                    // the device SHOULD reject the request and set status
                    // to VIRTIO_IOMMU_S_INVAL.
                    tail.status = VIRTIO_IOMMU_S_INVAL;
                }
                Err(e) => return Err(IommuError::MemoryMapper(e)),
            }
        }

        Ok(0)
    }

    fn process_dma_unmap_request(
        &mut self,
        reader: &mut Reader,
        tail: &mut virtio_iommu_req_tail,
    ) -> Result<(usize, Option<EventAsync>)> {
        let req: virtio_iommu_req_unmap = reader.read_obj().map_err(IommuError::GuestMemoryRead)?;

        let domain: u32 = req.domain.into();
        let fault_resolved_event = if let Some(mapper) = self.domain_map.get(&domain) {
            let size = u64::from(req.virt_end) - u64::from(req.virt_start) + 1;
            let res = mapper
                .1
                .lock()
                .remove_map(u64::from(req.virt_start), size)
                .map_err(IommuError::MemoryMapper)?;
            match res {
                RemoveMapResult::Success(evt) => evt,
                RemoveMapResult::OverlapFailure => {
                    // If a mapping affected by the range is not covered in its entirety by the
                    // range (the UNMAP request would split the mapping), then the device SHOULD
                    // set the request `status` to VIRTIO_IOMMU_S_RANGE, and SHOULD NOT remove
                    // any mapping.
                    tail.status = VIRTIO_IOMMU_S_RANGE;
                    None
                }
            }
        } else {
            // If domain does not exist, the device SHOULD set the
            // request status to VIRTIO_IOMMU_S_NOENT
            tail.status = VIRTIO_IOMMU_S_NOENT;
            None
        };

        Ok((0, fault_resolved_event))
    }

    #[cfg(target_arch = "x86_64")]
    fn process_probe_request(
        &mut self,
        reader: &mut Reader,
        writer: &mut Writer,
        tail: &mut virtio_iommu_req_tail,
    ) -> Result<usize> {
        let req: virtio_iommu_req_probe = reader.read_obj().map_err(IommuError::GuestMemoryRead)?;
        let endpoint: u32 = req.endpoint.into();

        // If the endpoint identified by endpoint doesn’t exist,
        // then the device SHOULD reject the request and set status
        // to VIRTIO_IOMMU_S_NOENT.
        if !self.endpoints.contains_key(&endpoint) {
            tail.status = VIRTIO_IOMMU_S_NOENT;
        }

        let properties_size = writer.available_bytes() - size_of::<virtio_iommu_req_tail>();

        // It's OK if properties_size is larger than probe_size
        // We are good even if properties_size is 0
        if properties_size < IOMMU_PROBE_SIZE {
            // If the properties list is smaller than probe_size, the device
            // SHOULD NOT write any property. It SHOULD reject the request
            // and set status to VIRTIO_IOMMU_S_INVAL.
            tail.status = VIRTIO_IOMMU_S_INVAL;
        } else if tail.status == VIRTIO_IOMMU_S_OK {
            const VIRTIO_IOMMU_PROBE_T_RESV_MEM: u16 = 1;
            const VIRTIO_IOMMU_RESV_MEM_T_MSI: u8 = 1;
            const PROBE_PROPERTY_SIZE: u16 = 4;
            const X86_MSI_IOVA_START: u64 = 0xfee0_0000;
            const X86_MSI_IOVA_END: u64 = 0xfeef_ffff;

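            // Advertise the x86 MSI doorbell window (0xfee0_0000..=0xfeef_ffff) as an MSI
            // reserved region, which tells the driver not to use these IOVAs for regular
            // DMA mappings.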
            let properties = virtio_iommu_probe_resv_mem {
                head: virtio_iommu_probe_property {
                    type_: VIRTIO_IOMMU_PROBE_T_RESV_MEM.into(),
                    length: (IOMMU_PROBE_SIZE as u16 - PROBE_PROPERTY_SIZE).into(),
                },
                subtype: VIRTIO_IOMMU_RESV_MEM_T_MSI,
                start: X86_MSI_IOVA_START.into(),
                end: X86_MSI_IOVA_END.into(),
                ..Default::default()
            };
            writer
                .write_all(properties.as_bytes())
                .map_err(IommuError::GuestMemoryWrite)?;
        }

        // If the device doesn’t fill all probe_size bytes with properties,
        // it SHOULD fill the remaining bytes of properties with zeroes.
        let remaining_bytes = writer.available_bytes() - size_of::<virtio_iommu_req_tail>();

        if remaining_bytes > 0 {
            let buffer: Vec<u8> = vec![0; remaining_bytes];
            writer
                .write_all(buffer.as_slice())
                .map_err(IommuError::GuestMemoryWrite)?;
        }

        Ok(properties_size)
    }

    fn execute_request(
        &mut self,
        avail_desc: &mut DescriptorChain,
    ) -> Result<(usize, Option<EventAsync>)> {
        let reader = &mut avail_desc.reader;
        let writer = &mut avail_desc.writer;

        // At a minimum, we need space to write a virtio_iommu_req_tail.
        if writer.available_bytes() < size_of::<virtio_iommu_req_tail>() {
            return Err(IommuError::WriteBufferTooSmall);
        }

        let req_head: virtio_iommu_req_head =
            reader.read_obj().map_err(IommuError::GuestMemoryRead)?;

        let mut tail = virtio_iommu_req_tail {
            status: VIRTIO_IOMMU_S_OK,
            ..Default::default()
        };

        let (reply_len, fault_resolved_event) = match req_head.type_ {
            VIRTIO_IOMMU_T_ATTACH => self.process_attach_request(reader, &mut tail)?,
            VIRTIO_IOMMU_T_DETACH => self.process_detach_request(reader, &mut tail)?,
            VIRTIO_IOMMU_T_MAP => (self.process_dma_map_request(reader, &mut tail)?, None),
            VIRTIO_IOMMU_T_UNMAP => self.process_dma_unmap_request(reader, &mut tail)?,
            #[cfg(target_arch = "x86_64")]
            VIRTIO_IOMMU_T_PROBE => (self.process_probe_request(reader, writer, &mut tail)?, None),
            _ => return Err(IommuError::UnexpectedDescriptor),
        };

        writer
            .write_all(tail.as_bytes())
            .map_err(IommuError::GuestMemoryWrite)?;
        Ok((
            reply_len + size_of::<virtio_iommu_req_tail>(),
            fault_resolved_event,
        ))
    }
}

async fn request_queue(
    state: &Rc<RefCell<State>>,
    mut queue: Queue,
    mut queue_event: EventAsync,
) -> Result<()> {
    loop {
        let mut avail_desc = queue
            .next_async(&mut queue_event)
            .await
            .map_err(IommuError::ReadAsyncDesc)?;

        let (len, fault_resolved_event) = match state.borrow_mut().execute_request(&mut avail_desc)
        {
            Ok(res) => res,
            Err(e) => {
                error!("execute_request failed: {}", e);

                // If a request type is not recognized, the device SHOULD NOT write
                // the buffer and SHOULD set the used length to zero
                (0, None)
            }
        };

        if let Some(fault_resolved_event) = fault_resolved_event {
            debug!("waiting for iommu fault resolution");
            fault_resolved_event
                .next_val()
                .await
                .expect("failed waiting for fault");
            debug!("iommu fault resolved");
        }

        queue.add_used(avail_desc, len as u32);
        queue.trigger_interrupt();
    }
}

fn run(
    state: State,
    iommu_device_tube: Tube,
    mut queues: BTreeMap<usize, Queue>,
    kill_evt: Event,
    translate_response_senders: Option<BTreeMap<u32, Tube>>,
    translate_request_rx: Option<Tube>,
) -> Result<()> {
    let state = Rc::new(RefCell::new(state));
    let ex = Executor::new().expect("Failed to create an executor");

    let req_queue = queues.remove(&0).unwrap();
    let req_evt = req_queue
        .event()
        .try_clone()
        .expect("Failed to clone queue event");
    let req_evt = EventAsync::new(req_evt, &ex).expect("Failed to create async event for queue");

    let f_kill = async_utils::await_and_exit(&ex, kill_evt);

    let request_tube = translate_request_rx
        .map(|t| AsyncTube::new(&ex, t).expect("Failed to create async tube for rx"));
    let response_tubes = translate_response_senders.map(|m| {
        m.into_iter()
            .map(|x| {
                (
                    x.0,
                    AsyncTube::new(&ex, x.1).expect("Failed to create async tube"),
                )
            })
            .collect()
    });

    let f_handle_translate_request =
        sys::handle_translate_request(&ex, &state, request_tube, response_tubes);
    let f_request = request_queue(&state, req_queue, req_evt);

    let command_tube = AsyncTube::new(&ex, iommu_device_tube).unwrap();
    // Future to handle command messages from the host, such as passing vfio containers.
    let f_cmd = sys::handle_command_tube(&state, command_tube);

    let done = async {
        select! {
            res = f_request.fuse() => res.context("error in handling request queue"),
            res = f_kill.fuse() => res.context("error in await_and_exit"),
            res = f_handle_translate_request.fuse() => {
                res.context("error in handle_translate_request")
            }
            res = f_cmd.fuse() => res.context("error in handling host request"),
        }
    };
    match ex.run_until(done) {
        Ok(Ok(())) => {}
        Ok(Err(e)) => error!("Error in worker: {:#}", e),
        Err(e) => return Err(IommuError::AsyncExec(e)),
    }

    Ok(())
}

/// Virtio device for IOMMU memory management.
pub struct Iommu {
    worker_thread: Option<WorkerThread<()>>,
    config: virtio_iommu_config,
    avail_features: u64,
    // Attached endpoints
    // key: endpoint PCI address
    // value: MemoryMapperTrait
    endpoints: BTreeMap<u32, Arc<Mutex<Box<dyn MemoryMapperTrait>>>>,
    // Hot-pluggable PCI endpoint ranges
    // RangeInclusive: (start endpoint PCI address ..= end endpoint PCI address)
    hp_endpoints_ranges: Vec<RangeInclusive<u32>>,
    translate_response_senders: Option<BTreeMap<u32, Tube>>,
    translate_request_rx: Option<Tube>,
    iommu_device_tube: Option<Tube>,
}

impl Iommu {
    /// Create a new virtio IOMMU device.
    pub fn new(
        base_features: u64,
        endpoints: BTreeMap<u32, Arc<Mutex<Box<dyn MemoryMapperTrait>>>>,
        iova_max_addr: u64,
        hp_endpoints_ranges: Vec<RangeInclusive<u32>>,
        translate_response_senders: Option<BTreeMap<u32, Tube>>,
        translate_request_rx: Option<Tube>,
        iommu_device_tube: Option<Tube>,
    ) -> SysResult<Iommu> {
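        // Start from the host page mask (e.g. !0xfff for 4 KiB pages) and intersect it with
        // each endpoint's supported page-size mask, so only granularities that every mapper
        // can handle remain set.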
        let mut page_size_mask = !((pagesize() as u64) - 1);
        for (_, container) in endpoints.iter() {
            page_size_mask &= container
                .lock()
                .get_mask()
                .map_err(|_e| SysError::new(libc::EIO))?;
        }

        if page_size_mask == 0 {
            return Err(SysError::new(libc::EIO));
        }

        let input_range = virtio_iommu_range_64 {
            start: Le64::from(0),
            end: iova_max_addr.into(),
        };

        let config = virtio_iommu_config {
            page_size_mask: page_size_mask.into(),
            input_range,
            #[cfg(target_arch = "x86_64")]
            probe_size: (IOMMU_PROBE_SIZE as u32).into(),
            ..Default::default()
        };

        let mut avail_features: u64 = base_features;
        avail_features |= 1 << VIRTIO_IOMMU_F_MAP_UNMAP
            | 1 << VIRTIO_IOMMU_F_INPUT_RANGE
            | 1 << VIRTIO_IOMMU_F_MMIO;

        if cfg!(target_arch = "x86_64") {
            avail_features |= 1 << VIRTIO_IOMMU_F_PROBE;
        }

        Ok(Iommu {
            worker_thread: None,
            config,
            avail_features,
            endpoints,
            hp_endpoints_ranges,
            translate_response_senders,
            translate_request_rx,
            iommu_device_tube,
        })
    }
}

impl VirtioDevice for Iommu {
    fn keep_rds(&self) -> Vec<RawDescriptor> {
        let mut rds = Vec::new();

        for (_, mapper) in self.endpoints.iter() {
            rds.append(&mut mapper.lock().as_raw_descriptors());
        }
        if let Some(senders) = &self.translate_response_senders {
            for (_, tube) in senders.iter() {
                rds.push(tube.as_raw_descriptor());
            }
        }
        if let Some(rx) = &self.translate_request_rx {
            rds.push(rx.as_raw_descriptor());
        }

        if let Some(iommu_device_tube) = &self.iommu_device_tube {
            rds.push(iommu_device_tube.as_raw_descriptor());
        }

        rds
    }

    fn device_type(&self) -> DeviceType {
        DeviceType::Iommu
    }

    fn queue_max_sizes(&self) -> &[u16] {
        QUEUE_SIZES
    }

    fn features(&self) -> u64 {
        self.avail_features
    }

    fn read_config(&self, offset: u64, data: &mut [u8]) {
        let mut config: Vec<u8> = Vec::new();
        config.extend_from_slice(self.config.as_bytes());
        copy_config(data, 0, config.as_slice(), offset);
    }

    fn activate(
        &mut self,
        mem: GuestMemory,
        _interrupt: Interrupt,
        queues: BTreeMap<usize, Queue>,
    ) -> anyhow::Result<()> {
        if queues.len() != QUEUE_SIZES.len() {
            return Err(anyhow!(
                "expected {} queues, got {}",
                QUEUE_SIZES.len(),
                queues.len()
            ));
        }

        // The least significant set bit of page_size_mask defines the page
        // granularity of IOMMU mappings
        let page_mask = (1u64 << u64::from(self.config.page_size_mask).trailing_zeros()) - 1;
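        // e.g. a page_size_mask of 0xffff_ffff_ffff_f000 (4 KiB granularity) has 12 trailing
        // zeros, giving page_mask = 0xfff.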
        let eps = self.endpoints.clone();
        let hp_endpoints_ranges = self.hp_endpoints_ranges.to_owned();

        let translate_response_senders = self.translate_response_senders.take();
        let translate_request_rx = self.translate_request_rx.take();

        let iommu_device_tube = self
            .iommu_device_tube
            .take()
            .context("failed to start virtio-iommu worker: No control tube")?;

        self.worker_thread = Some(WorkerThread::start("v_iommu", move |kill_evt| {
            let state = State {
                mem,
                page_mask,
                hp_endpoints_ranges,
                endpoint_map: BTreeMap::new(),
                domain_map: BTreeMap::new(),
                endpoints: eps,
                dmabuf_mem: BTreeMap::new(),
            };
            let result = run(
                state,
                iommu_device_tube,
                queues,
                kill_evt,
                translate_response_senders,
                translate_request_rx,
            );
            if let Err(e) = result {
                error!("virtio-iommu worker thread exited with error: {}", e);
            }
        }));
        Ok(())
    }

    #[cfg(target_arch = "x86_64")]
    fn generate_acpi(&mut self, pci_address: PciAddress, sdts: &mut Vec<SDT>) {
        const OEM_REVISION: u32 = 1;
        const VIOT_REVISION: u8 = 0;

        for sdt in sdts.iter() {
            // there should only be one VIOT table
            if sdt.is_signature(b"VIOT") {
                warn!("vIOMMU: duplicate VIOT table detected");
                return;
            }
        }

        let mut viot = SDT::new(
            *b"VIOT",
            acpi_tables::HEADER_LEN,
            VIOT_REVISION,
            *b"CROSVM",
            *b"CROSVMDT",
            OEM_REVISION,
        );
        viot.append(VirtioIommuViotHeader {
            // # of PCI range nodes + 1 virtio-pci node
            node_count: (self.endpoints.len() + self.hp_endpoints_ranges.len() + 1) as u16,
            node_offset: (viot.len() + std::mem::size_of::<VirtioIommuViotHeader>()) as u16,
            ..Default::default()
        });

        let bdf = pci_address.to_u32() as u16;
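        // Assumption: PciAddress::to_u32() packs bus/device/function in the low bits, so the
        // low 16 bits form the BDF value expected by the VIOT virtio-pci node.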
        let iommu_offset = viot.len();

        viot.append(VirtioIommuViotVirtioPciNode {
            type_: VIRTIO_IOMMU_VIOT_NODE_VIRTIO_IOMMU_PCI,
            length: size_of::<VirtioIommuViotVirtioPciNode>() as u16,
            bdf,
            ..Default::default()
        });

        for (endpoint, _) in self.endpoints.iter() {
            viot.append(VirtioIommuViotPciRangeNode {
                type_: VIRTIO_IOMMU_VIOT_NODE_PCI_RANGE,
                length: size_of::<VirtioIommuViotPciRangeNode>() as u16,
                endpoint_start: *endpoint,
                bdf_start: *endpoint as u16,
                bdf_end: *endpoint as u16,
                output_node: iommu_offset as u16,
                ..Default::default()
            });
        }

        for endpoints_range in self.hp_endpoints_ranges.iter() {
            let (endpoint_start, endpoint_end) = endpoints_range.clone().into_inner();
            viot.append(VirtioIommuViotPciRangeNode {
                type_: VIRTIO_IOMMU_VIOT_NODE_PCI_RANGE,
                length: size_of::<VirtioIommuViotPciRangeNode>() as u16,
                endpoint_start,
                bdf_start: endpoint_start as u16,
                bdf_end: endpoint_end as u16,
                output_node: iommu_offset as u16,
                ..Default::default()
            });
        }

        sdts.push(viot);
    }
}