• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 // Copyright 2019 The Chromium OS Authors. All rights reserved.
2 // Use of this source code is governed by a BSD-style license that can be
3 // found in the LICENSE file.
4 
5 use std::borrow::Cow;
6 use std::cmp;
7 use std::convert::TryInto;
8 use std::io::{self, Write};
9 use std::iter::FromIterator;
10 use std::marker::PhantomData;
11 use std::ptr::copy_nonoverlapping;
12 use std::result;
13 use std::sync::Arc;
14 
15 use anyhow::Context;
16 use base::{FileReadWriteAtVolatile, FileReadWriteVolatile};
17 use cros_async::MemRegion;
18 use data_model::{DataInit, Le16, Le32, Le64, VolatileMemoryError, VolatileSlice};
19 use disk::AsyncDisk;
20 use remain::sorted;
21 use smallvec::SmallVec;
22 use thiserror::Error;
23 use vm_memory::{GuestAddress, GuestMemory};
24 
25 use super::DescriptorChain;
26 
27 #[sorted]
28 #[derive(Error, Debug)]
29 pub enum Error {
30     #[error("the combined length of all the buffers in a `DescriptorChain` would overflow")]
31     DescriptorChainOverflow,
32     #[error("descriptor guest memory error: {0}")]
33     GuestMemoryError(vm_memory::GuestMemoryError),
34     #[error("invalid descriptor chain: {0:#}")]
35     InvalidChain(anyhow::Error),
36     #[error("descriptor I/O error: {0}")]
37     IoError(io::Error),
38     #[error("`DescriptorChain` split is out of bounds: {0}")]
39     SplitOutOfBounds(usize),
40     #[error("volatile memory error: {0}")]
41     VolatileMemoryError(VolatileMemoryError),
42 }
43 
44 pub type Result<T> = result::Result<T, Error>;
45 
46 #[derive(Clone)]
47 struct DescriptorChainRegions {
48     regions: SmallVec<[MemRegion; 16]>,
49     current: usize,
50     bytes_consumed: usize,
51 }
52 
53 impl DescriptorChainRegions {
available_bytes(&self) -> usize54     fn available_bytes(&self) -> usize {
55         // This is guaranteed not to overflow because the total length of the chain
56         // is checked during all creations of `DescriptorChainRegions` (see
57         // `Reader::new()` and `Writer::new()`).
58         self.get_remaining_regions()
59             .iter()
60             .fold(0usize, |count, region| count + region.len)
61     }
62 
bytes_consumed(&self) -> usize63     fn bytes_consumed(&self) -> usize {
64         self.bytes_consumed
65     }
66 
67     /// Returns all the remaining buffers in the `DescriptorChain`. Calling this function does not
68     /// consume any bytes from the `DescriptorChain`. Instead callers should use the `consume`
69     /// method to advance the `DescriptorChain`. Multiple calls to `get` with no intervening calls
70     /// to `consume` will return the same data.
get_remaining_regions(&self) -> &[MemRegion]71     fn get_remaining_regions(&self) -> &[MemRegion] {
72         &self.regions[self.current..]
73     }
74 
75     /// Returns all the remaining buffers in the `DescriptorChain` as `VolatileSlice`s of the given
76     /// `GuestMemory`. Calling this function does not consume any bytes from the `DescriptorChain`.
77     /// Instead callers should use the `consume` method to advance the `DescriptorChain`. Multiple
78     /// calls to `get` with no intervening calls to `consume` will return the same data.
get_remaining<'mem>(&self, mem: &'mem GuestMemory) -> SmallVec<[VolatileSlice<'mem>; 16]>79     fn get_remaining<'mem>(&self, mem: &'mem GuestMemory) -> SmallVec<[VolatileSlice<'mem>; 16]> {
80         self.get_remaining_regions()
81             .iter()
82             .filter_map(|region| {
83                 mem.get_slice_at_addr(GuestAddress(region.offset), region.len)
84                     .ok()
85             })
86             .collect()
87     }
88 
89     /// Like `get_remaining` but guarantees that the combined length of all the returned iovecs is
90     /// not greater than `count`. The combined length of the returned iovecs may be less than
91     /// `count` but will always be greater than 0 as long as there is still space left in the
92     /// `DescriptorChain`.
get_remaining_regions_with_count(&self, count: usize) -> Cow<[MemRegion]>93     fn get_remaining_regions_with_count(&self, count: usize) -> Cow<[MemRegion]> {
94         let regions = self.get_remaining_regions();
95         let mut region_count = 0;
96         let mut rem = count;
97         for region in regions {
98             if rem < region.len {
99                 break;
100             }
101 
102             region_count += 1;
103             rem -= region.len;
104         }
105 
106         // Special case where the number of bytes to be copied is smaller than the `size()` of the
107         // first regions.
108         if region_count == 0 && !regions.is_empty() && count > 0 {
109             debug_assert!(count < regions[0].len);
110             // Safe because we know that count is smaller than the length of the first slice.
111             Cow::Owned(vec![MemRegion {
112                 offset: regions[0].offset,
113                 len: count,
114             }])
115         } else {
116             Cow::Borrowed(&regions[..region_count])
117         }
118     }
119 
120     /// Like 'get_remaining_with_count' except convert the offsets to volatile slices in the
121     /// 'GuestMemory' given by 'mem'.
get_remaining_with_count<'mem>( &self, mem: &'mem GuestMemory, count: usize, ) -> SmallVec<[VolatileSlice<'mem>; 16]>122     fn get_remaining_with_count<'mem>(
123         &self,
124         mem: &'mem GuestMemory,
125         count: usize,
126     ) -> SmallVec<[VolatileSlice<'mem>; 16]> {
127         self.get_remaining_regions_with_count(count)
128             .iter()
129             .filter_map(|region| {
130                 mem.get_slice_at_addr(GuestAddress(region.offset), region.len)
131                     .ok()
132             })
133             .collect()
134     }
135 
136     /// Consumes `count` bytes from the `DescriptorChain`. If `count` is larger than
137     /// `self.available_bytes()` then all remaining bytes in the `DescriptorChain` will be consumed.
consume(&mut self, mut count: usize)138     fn consume(&mut self, mut count: usize) {
139         // The implementation is adapted from `IoSlice::advance` in libstd. We can't use
140         // `get_remaining` here because then the compiler complains that `self.current` is already
141         // borrowed and doesn't allow us to modify it.  We also need to borrow the iovecs mutably.
142         let current = self.current;
143         for region in &mut self.regions[current..] {
144             if count == 0 {
145                 break;
146             }
147 
148             let consumed = if count < region.len {
149                 // Safe because we know that the iovec pointed to valid memory and we are adding a
150                 // value that is smaller than the length of the memory.
151                 *region = MemRegion {
152                     offset: region.offset + count as u64,
153                     len: region.len - count,
154                 };
155                 count
156             } else {
157                 self.current += 1;
158                 region.len
159             };
160 
161             // This shouldn't overflow because `consumed <= buf.size()` and we already verified
162             // that adding all `buf.size()` values will not overflow when the Reader/Writer was
163             // constructed.
164             self.bytes_consumed += consumed;
165             count -= consumed;
166         }
167     }
168 
split_at(&mut self, offset: usize) -> DescriptorChainRegions169     fn split_at(&mut self, offset: usize) -> DescriptorChainRegions {
170         let mut other = self.clone();
171         other.consume(offset);
172         other.bytes_consumed = 0;
173 
174         let mut rem = offset;
175         let mut end = self.current;
176         for region in &mut self.regions[self.current..] {
177             if rem <= region.len {
178                 region.len = rem;
179                 break;
180             }
181 
182             end += 1;
183             rem -= region.len;
184         }
185 
186         self.regions.truncate(end + 1);
187 
188         other
189     }
190 }
191 
192 /// Provides high-level interface over the sequence of memory regions
193 /// defined by readable descriptors in the descriptor chain.
194 ///
195 /// Note that virtio spec requires driver to place any device-writable
196 /// descriptors after any device-readable descriptors (2.6.4.2 in Virtio Spec v1.1).
197 /// Reader will skip iterating over descriptor chain when first writable
198 /// descriptor is encountered.
199 #[derive(Clone)]
200 pub struct Reader {
201     mem: GuestMemory,
202     regions: DescriptorChainRegions,
203 }
204 
205 /// An iterator over `DataInit` objects on readable descriptors in the descriptor chain.
206 pub struct ReaderIterator<'a, T: DataInit> {
207     reader: &'a mut Reader,
208     phantom: PhantomData<T>,
209 }
210 
211 impl<'a, T: DataInit> Iterator for ReaderIterator<'a, T> {
212     type Item = io::Result<T>;
213 
next(&mut self) -> Option<io::Result<T>>214     fn next(&mut self) -> Option<io::Result<T>> {
215         if self.reader.available_bytes() == 0 {
216             None
217         } else {
218             Some(self.reader.read_obj())
219         }
220     }
221 }
222 
223 /// Get all the mem regions from a `DescriptorChain` iterator, regardless if the `DescriptorChain`
224 /// contains GPAs (guest physical address), or IOVAs (io virtual address). IOVAs will
225 /// be translated to GPAs via IOMMU.
get_mem_regions<I>(mem: &GuestMemory, vals: I) -> Result<SmallVec<[MemRegion; 16]>> where I: Iterator<Item = DescriptorChain>,226 pub fn get_mem_regions<I>(mem: &GuestMemory, vals: I) -> Result<SmallVec<[MemRegion; 16]>>
227 where
228     I: Iterator<Item = DescriptorChain>,
229 {
230     let mut total_len: usize = 0;
231     let mut regions = SmallVec::new();
232 
233     // TODO(jstaron): Update this code to take the indirect descriptors into account.
234     for desc in vals {
235         // Verify that summing the descriptor sizes does not overflow.
236         // This can happen if a driver tricks a device into reading/writing more data than
237         // fits in a `usize`.
238         total_len = total_len
239             .checked_add(desc.len as usize)
240             .ok_or(Error::DescriptorChainOverflow)?;
241 
242         for r in desc.get_mem_regions().map_err(Error::InvalidChain)? {
243             // Check that all the regions are totally contained in GuestMemory.
244             mem.get_slice_at_addr(r.gpa, r.len.try_into().expect("u32 doesn't fit in usize"))
245                 .map_err(Error::GuestMemoryError)?;
246 
247             regions.push(MemRegion {
248                 offset: r.gpa.offset(),
249                 len: r.len.try_into().expect("u32 doesn't fit in usize"),
250             });
251         }
252     }
253 
254     Ok(regions)
255 }
256 
257 impl Reader {
258     /// Construct a new Reader wrapper over `desc_chain`.
new(mem: GuestMemory, desc_chain: DescriptorChain) -> Result<Reader>259     pub fn new(mem: GuestMemory, desc_chain: DescriptorChain) -> Result<Reader> {
260         let regions = get_mem_regions(&mem, desc_chain.into_iter().readable())?;
261         Ok(Reader {
262             mem,
263             regions: DescriptorChainRegions {
264                 regions,
265                 current: 0,
266                 bytes_consumed: 0,
267             },
268         })
269     }
270 
271     /// Reads an object from the descriptor chain buffer.
read_obj<T: DataInit>(&mut self) -> io::Result<T>272     pub fn read_obj<T: DataInit>(&mut self) -> io::Result<T> {
273         T::from_reader(self)
274     }
275 
276     /// Reads objects by consuming all the remaining data in the descriptor chain buffer and returns
277     /// them as a collection. Returns an error if the size of the remaining data is indivisible by
278     /// the size of an object of type `T`.
collect<C: FromIterator<io::Result<T>>, T: DataInit>(&mut self) -> C279     pub fn collect<C: FromIterator<io::Result<T>>, T: DataInit>(&mut self) -> C {
280         self.iter().collect()
281     }
282 
283     /// Creates an iterator for sequentially reading `DataInit` objects from the `Reader`.
284     /// Unlike `collect`, this doesn't consume all the remaining data in the `Reader` and
285     /// doesn't require the objects to be stored in a separate collection.
iter<T: DataInit>(&mut self) -> ReaderIterator<T>286     pub fn iter<T: DataInit>(&mut self) -> ReaderIterator<T> {
287         ReaderIterator {
288             reader: self,
289             phantom: PhantomData,
290         }
291     }
292 
293     /// Reads data into a volatile slice up to the minimum of the slice's length or the number of
294     /// bytes remaining. Returns the number of bytes read.
read_to_volatile_slice(&mut self, slice: VolatileSlice) -> usize295     pub fn read_to_volatile_slice(&mut self, slice: VolatileSlice) -> usize {
296         let mut read = 0usize;
297         let mut dst = slice;
298         for src in self.get_remaining() {
299             src.copy_to_volatile_slice(dst);
300             let copied = std::cmp::min(src.size(), dst.size());
301             read += copied;
302             dst = match dst.offset(copied) {
303                 Ok(v) => v,
304                 Err(_) => break, // The slice is fully consumed
305             };
306         }
307         self.regions.consume(read);
308         read
309     }
310 
311     /// Reads data from the descriptor chain buffer into a file descriptor.
312     /// Returns the number of bytes read from the descriptor chain buffer.
313     /// The number of bytes read can be less than `count` if there isn't
314     /// enough data in the descriptor chain buffer.
read_to<F: FileReadWriteVolatile>( &mut self, mut dst: F, count: usize, ) -> io::Result<usize>315     pub fn read_to<F: FileReadWriteVolatile>(
316         &mut self,
317         mut dst: F,
318         count: usize,
319     ) -> io::Result<usize> {
320         let iovs = self.regions.get_remaining_with_count(&self.mem, count);
321         let written = dst.write_vectored_volatile(&iovs[..])?;
322         self.regions.consume(written);
323         Ok(written)
324     }
325 
326     /// Reads data from the descriptor chain buffer into a File at offset `off`.
327     /// Returns the number of bytes read from the descriptor chain buffer.
328     /// The number of bytes read can be less than `count` if there isn't
329     /// enough data in the descriptor chain buffer.
read_to_at<F: FileReadWriteAtVolatile>( &mut self, mut dst: F, count: usize, off: u64, ) -> io::Result<usize>330     pub fn read_to_at<F: FileReadWriteAtVolatile>(
331         &mut self,
332         mut dst: F,
333         count: usize,
334         off: u64,
335     ) -> io::Result<usize> {
336         let iovs = self.regions.get_remaining_with_count(&self.mem, count);
337         let written = dst.write_vectored_at_volatile(&iovs[..], off)?;
338         self.regions.consume(written);
339         Ok(written)
340     }
341 
342     /// Reads data from the descriptor chain similar to 'read_to' except reading 'count' or
343     /// returning an error if 'count' bytes can't be read.
read_exact_to<F: FileReadWriteVolatile>( &mut self, mut dst: F, mut count: usize, ) -> io::Result<()>344     pub fn read_exact_to<F: FileReadWriteVolatile>(
345         &mut self,
346         mut dst: F,
347         mut count: usize,
348     ) -> io::Result<()> {
349         while count > 0 {
350             match self.read_to(&mut dst, count) {
351                 Ok(0) => {
352                     return Err(io::Error::new(
353                         io::ErrorKind::UnexpectedEof,
354                         "failed to fill whole buffer",
355                     ))
356                 }
357                 Ok(n) => count -= n,
358                 Err(ref e) if e.kind() == io::ErrorKind::Interrupted => {}
359                 Err(e) => return Err(e),
360             }
361         }
362 
363         Ok(())
364     }
365 
366     /// Reads data from the descriptor chain similar to 'read_to_at' except reading 'count' or
367     /// returning an error if 'count' bytes can't be read.
read_exact_to_at<F: FileReadWriteAtVolatile>( &mut self, mut dst: F, mut count: usize, mut off: u64, ) -> io::Result<()>368     pub fn read_exact_to_at<F: FileReadWriteAtVolatile>(
369         &mut self,
370         mut dst: F,
371         mut count: usize,
372         mut off: u64,
373     ) -> io::Result<()> {
374         while count > 0 {
375             match self.read_to_at(&mut dst, count, off) {
376                 Ok(0) => {
377                     return Err(io::Error::new(
378                         io::ErrorKind::UnexpectedEof,
379                         "failed to fill whole buffer",
380                     ))
381                 }
382                 Ok(n) => {
383                     count -= n;
384                     off += n as u64;
385                 }
386                 Err(ref e) if e.kind() == io::ErrorKind::Interrupted => {}
387                 Err(e) => return Err(e),
388             }
389         }
390 
391         Ok(())
392     }
393 
394     /// Reads data from the descriptor chain buffer into an `AsyncDisk` at offset `off`.
395     /// Returns the number of bytes read from the descriptor chain buffer.
396     /// The number of bytes read can be less than `count` if there isn't
397     /// enough data in the descriptor chain buffer.
read_to_at_fut<F: AsyncDisk + ?Sized>( &mut self, dst: &F, count: usize, off: u64, ) -> disk::Result<usize>398     pub async fn read_to_at_fut<F: AsyncDisk + ?Sized>(
399         &mut self,
400         dst: &F,
401         count: usize,
402         off: u64,
403     ) -> disk::Result<usize> {
404         let mem_regions = self.regions.get_remaining_regions_with_count(count);
405         let written = dst
406             .write_from_mem(off, Arc::new(self.mem.clone()), &mem_regions)
407             .await?;
408         self.regions.consume(written);
409         Ok(written)
410     }
411 
412     /// Reads exactly `count` bytes from the chain to the disk asynchronously or returns an error if
413     /// not enough data can be read.
read_exact_to_at_fut<F: AsyncDisk + ?Sized>( &mut self, dst: &F, mut count: usize, mut off: u64, ) -> disk::Result<()>414     pub async fn read_exact_to_at_fut<F: AsyncDisk + ?Sized>(
415         &mut self,
416         dst: &F,
417         mut count: usize,
418         mut off: u64,
419     ) -> disk::Result<()> {
420         while count > 0 {
421             let nread = self.read_to_at_fut(dst, count, off).await?;
422             if nread == 0 {
423                 return Err(disk::Error::ReadingData(io::Error::new(
424                     io::ErrorKind::UnexpectedEof,
425                     "failed to write whole buffer",
426                 )));
427             }
428             count -= nread;
429             off += nread as u64;
430         }
431 
432         Ok(())
433     }
434 
435     /// Returns number of bytes available for reading.  May return an error if the combined
436     /// lengths of all the buffers in the DescriptorChain would cause an integer overflow.
available_bytes(&self) -> usize437     pub fn available_bytes(&self) -> usize {
438         self.regions.available_bytes()
439     }
440 
441     /// Returns number of bytes already read from the descriptor chain buffer.
bytes_read(&self) -> usize442     pub fn bytes_read(&self) -> usize {
443         self.regions.bytes_consumed()
444     }
445 
446     /// Returns a `&[VolatileSlice]` that represents all the remaining data in this `Reader`.
447     /// Calling this method does not actually consume any data from the `Reader` and callers should
448     /// call `consume` to advance the `Reader`.
get_remaining(&self) -> SmallVec<[VolatileSlice; 16]>449     pub fn get_remaining(&self) -> SmallVec<[VolatileSlice; 16]> {
450         self.regions.get_remaining(&self.mem)
451     }
452 
453     /// Consumes `amt` bytes from the underlying descriptor chain. If `amt` is larger than the
454     /// remaining data left in this `Reader`, then all remaining data will be consumed.
consume(&mut self, amt: usize)455     pub fn consume(&mut self, amt: usize) {
456         self.regions.consume(amt)
457     }
458 
459     /// Splits this `Reader` into two at the given offset in the `DescriptorChain` buffer. After the
460     /// split, `self` will be able to read up to `offset` bytes while the returned `Reader` can read
461     /// up to `available_bytes() - offset` bytes. If `offset > self.available_bytes()`, then the
462     /// returned `Reader` will not be able to read any bytes.
split_at(&mut self, offset: usize) -> Reader463     pub fn split_at(&mut self, offset: usize) -> Reader {
464         Reader {
465             mem: self.mem.clone(),
466             regions: self.regions.split_at(offset),
467         }
468     }
469 }
470 
471 impl io::Read for Reader {
read(&mut self, buf: &mut [u8]) -> io::Result<usize>472     fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
473         let mut rem = buf;
474         let mut total = 0;
475         for b in self.regions.get_remaining(&self.mem) {
476             if rem.is_empty() {
477                 break;
478             }
479 
480             let count = cmp::min(rem.len(), b.size());
481 
482             // Safe because we have already verified that `b` points to valid memory.
483             unsafe {
484                 copy_nonoverlapping(b.as_ptr(), rem.as_mut_ptr(), count);
485             }
486             rem = &mut rem[count..];
487             total += count;
488         }
489 
490         self.regions.consume(total);
491         Ok(total)
492     }
493 }
494 
495 /// Provides high-level interface over the sequence of memory regions
496 /// defined by writable descriptors in the descriptor chain.
497 ///
498 /// Note that virtio spec requires driver to place any device-writable
499 /// descriptors after any device-readable descriptors (2.6.4.2 in Virtio Spec v1.1).
500 /// Writer will start iterating the descriptors from the first writable one and will
501 /// assume that all following descriptors are writable.
502 #[derive(Clone)]
503 pub struct Writer {
504     mem: GuestMemory,
505     regions: DescriptorChainRegions,
506 }
507 
508 impl Writer {
509     /// Construct a new Writer wrapper over `desc_chain`.
new(mem: GuestMemory, desc_chain: DescriptorChain) -> Result<Writer>510     pub fn new(mem: GuestMemory, desc_chain: DescriptorChain) -> Result<Writer> {
511         let regions = get_mem_regions(&mem, desc_chain.into_iter().writable())?;
512         Ok(Writer {
513             mem,
514             regions: DescriptorChainRegions {
515                 regions,
516                 current: 0,
517                 bytes_consumed: 0,
518             },
519         })
520     }
521 
522     /// Writes an object to the descriptor chain buffer.
write_obj<T: DataInit>(&mut self, val: T) -> io::Result<()>523     pub fn write_obj<T: DataInit>(&mut self, val: T) -> io::Result<()> {
524         self.write_all(val.as_slice())
525     }
526 
527     /// Writes all objects produced by `iter` into the descriptor chain buffer. Unlike `consume`,
528     /// this doesn't require the values to be stored in an intermediate collection first. It also
529     /// allows callers to choose which elements in a collection to write, for example by using the
530     /// `filter` or `take` methods of the `Iterator` trait.
write_iter<T: DataInit, I: Iterator<Item = T>>( &mut self, mut iter: I, ) -> io::Result<()>531     pub fn write_iter<T: DataInit, I: Iterator<Item = T>>(
532         &mut self,
533         mut iter: I,
534     ) -> io::Result<()> {
535         iter.try_for_each(|v| self.write_obj(v))
536     }
537 
538     /// Writes a collection of objects into the descriptor chain buffer.
consume<T: DataInit, C: IntoIterator<Item = T>>(&mut self, vals: C) -> io::Result<()>539     pub fn consume<T: DataInit, C: IntoIterator<Item = T>>(&mut self, vals: C) -> io::Result<()> {
540         self.write_iter(vals.into_iter())
541     }
542 
543     /// Returns number of bytes available for writing.  May return an error if the combined
544     /// lengths of all the buffers in the DescriptorChain would cause an overflow.
available_bytes(&self) -> usize545     pub fn available_bytes(&self) -> usize {
546         self.regions.available_bytes()
547     }
548 
549     /// Reads data into a volatile slice up to the minimum of the slice's length or the number of
550     /// bytes remaining. Returns the number of bytes read.
write_from_volatile_slice(&mut self, slice: VolatileSlice) -> usize551     pub fn write_from_volatile_slice(&mut self, slice: VolatileSlice) -> usize {
552         let mut written = 0usize;
553         let mut src = slice;
554         for dst in self.get_remaining() {
555             src.copy_to_volatile_slice(dst);
556             let copied = std::cmp::min(src.size(), dst.size());
557             written += copied;
558             src = match src.offset(copied) {
559                 Ok(v) => v,
560                 Err(_) => break, // The slice is fully consumed
561             };
562         }
563         self.regions.consume(written);
564         written
565     }
566 
567     /// Writes data to the descriptor chain buffer from a file descriptor.
568     /// Returns the number of bytes written to the descriptor chain buffer.
569     /// The number of bytes written can be less than `count` if
570     /// there isn't enough data in the descriptor chain buffer.
write_from<F: FileReadWriteVolatile>( &mut self, mut src: F, count: usize, ) -> io::Result<usize>571     pub fn write_from<F: FileReadWriteVolatile>(
572         &mut self,
573         mut src: F,
574         count: usize,
575     ) -> io::Result<usize> {
576         let iovs = self.regions.get_remaining_with_count(&self.mem, count);
577         let read = src.read_vectored_volatile(&iovs[..])?;
578         self.regions.consume(read);
579         Ok(read)
580     }
581 
582     /// Writes data to the descriptor chain buffer from a File at offset `off`.
583     /// Returns the number of bytes written to the descriptor chain buffer.
584     /// The number of bytes written can be less than `count` if
585     /// there isn't enough data in the descriptor chain buffer.
write_from_at<F: FileReadWriteAtVolatile>( &mut self, mut src: F, count: usize, off: u64, ) -> io::Result<usize>586     pub fn write_from_at<F: FileReadWriteAtVolatile>(
587         &mut self,
588         mut src: F,
589         count: usize,
590         off: u64,
591     ) -> io::Result<usize> {
592         let iovs = self.regions.get_remaining_with_count(&self.mem, count);
593         let read = src.read_vectored_at_volatile(&iovs[..], off)?;
594         self.regions.consume(read);
595         Ok(read)
596     }
597 
write_all_from<F: FileReadWriteVolatile>( &mut self, mut src: F, mut count: usize, ) -> io::Result<()>598     pub fn write_all_from<F: FileReadWriteVolatile>(
599         &mut self,
600         mut src: F,
601         mut count: usize,
602     ) -> io::Result<()> {
603         while count > 0 {
604             match self.write_from(&mut src, count) {
605                 Ok(0) => {
606                     return Err(io::Error::new(
607                         io::ErrorKind::WriteZero,
608                         "failed to write whole buffer",
609                     ))
610                 }
611                 Ok(n) => count -= n,
612                 Err(ref e) if e.kind() == io::ErrorKind::Interrupted => {}
613                 Err(e) => return Err(e),
614             }
615         }
616 
617         Ok(())
618     }
619 
write_all_from_at<F: FileReadWriteAtVolatile>( &mut self, mut src: F, mut count: usize, mut off: u64, ) -> io::Result<()>620     pub fn write_all_from_at<F: FileReadWriteAtVolatile>(
621         &mut self,
622         mut src: F,
623         mut count: usize,
624         mut off: u64,
625     ) -> io::Result<()> {
626         while count > 0 {
627             match self.write_from_at(&mut src, count, off) {
628                 Ok(0) => {
629                     return Err(io::Error::new(
630                         io::ErrorKind::WriteZero,
631                         "failed to write whole buffer",
632                     ))
633                 }
634                 Ok(n) => {
635                     count -= n;
636                     off += n as u64;
637                 }
638                 Err(ref e) if e.kind() == io::ErrorKind::Interrupted => {}
639                 Err(e) => return Err(e),
640             }
641         }
642         Ok(())
643     }
644     /// Writes data to the descriptor chain buffer from an `AsyncDisk` at offset `off`.
645     /// Returns the number of bytes written to the descriptor chain buffer.
646     /// The number of bytes written can be less than `count` if
647     /// there isn't enough data in the descriptor chain buffer.
write_from_at_fut<F: AsyncDisk + ?Sized>( &mut self, src: &F, count: usize, off: u64, ) -> disk::Result<usize>648     pub async fn write_from_at_fut<F: AsyncDisk + ?Sized>(
649         &mut self,
650         src: &F,
651         count: usize,
652         off: u64,
653     ) -> disk::Result<usize> {
654         let regions = self.regions.get_remaining_regions_with_count(count);
655         let read = src
656             .read_to_mem(off, Arc::new(self.mem.clone()), &regions)
657             .await?;
658         self.regions.consume(read);
659         Ok(read)
660     }
661 
write_all_from_at_fut<F: AsyncDisk + ?Sized>( &mut self, src: &F, mut count: usize, mut off: u64, ) -> disk::Result<()>662     pub async fn write_all_from_at_fut<F: AsyncDisk + ?Sized>(
663         &mut self,
664         src: &F,
665         mut count: usize,
666         mut off: u64,
667     ) -> disk::Result<()> {
668         while count > 0 {
669             let nwritten = self.write_from_at_fut(src, count, off).await?;
670             if nwritten == 0 {
671                 return Err(disk::Error::WritingData(io::Error::new(
672                     io::ErrorKind::UnexpectedEof,
673                     "failed to write whole buffer",
674                 )));
675             }
676             count -= nwritten;
677             off += nwritten as u64;
678         }
679         Ok(())
680     }
681 
682     /// Returns number of bytes already written to the descriptor chain buffer.
bytes_written(&self) -> usize683     pub fn bytes_written(&self) -> usize {
684         self.regions.bytes_consumed()
685     }
686 
687     /// Returns a `&[VolatileSlice]` that represents all the remaining data in this `Writer`.
688     /// Calling this method does not actually advance the current position of the `Writer` in the
689     /// buffer and callers should call `consume_bytes` to advance the `Writer`. Not calling
690     /// `consume_bytes` with the amount of data copied into the returned `VolatileSlice`s will
691     /// result in that that data being overwritten the next time data is written into the `Writer`.
get_remaining(&self) -> SmallVec<[VolatileSlice; 16]>692     pub fn get_remaining(&self) -> SmallVec<[VolatileSlice; 16]> {
693         self.regions.get_remaining(&self.mem)
694     }
695 
696     /// Consumes `amt` bytes from the underlying descriptor chain. If `amt` is larger than the
697     /// remaining data left in this `Reader`, then all remaining data will be consumed.
consume_bytes(&mut self, amt: usize)698     pub fn consume_bytes(&mut self, amt: usize) {
699         self.regions.consume(amt)
700     }
701 
702     /// Splits this `Writer` into two at the given offset in the `DescriptorChain` buffer. After the
703     /// split, `self` will be able to write up to `offset` bytes while the returned `Writer` can
704     /// write up to `available_bytes() - offset` bytes. If `offset > self.available_bytes()`, then
705     /// the returned `Writer` will not be able to write any bytes.
split_at(&mut self, offset: usize) -> Writer706     pub fn split_at(&mut self, offset: usize) -> Writer {
707         Writer {
708             mem: self.mem.clone(),
709             regions: self.regions.split_at(offset),
710         }
711     }
712 }
713 
714 impl io::Write for Writer {
write(&mut self, buf: &[u8]) -> io::Result<usize>715     fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
716         let mut rem = buf;
717         let mut total = 0;
718         for b in self.regions.get_remaining(&self.mem) {
719             if rem.is_empty() {
720                 break;
721             }
722 
723             let count = cmp::min(rem.len(), b.size());
724             // Safe because we have already verified that `vs` points to valid memory.
725             unsafe {
726                 copy_nonoverlapping(rem.as_ptr(), b.as_mut_ptr(), count);
727             }
728             rem = &rem[count..];
729             total += count;
730         }
731 
732         self.regions.consume(total);
733         Ok(total)
734     }
735 
flush(&mut self) -> io::Result<()>736     fn flush(&mut self) -> io::Result<()> {
737         // Nothing to flush since the writes go straight into the buffer.
738         Ok(())
739     }
740 }
741 
// Descriptor flag bits stored in `virtq_desc::flags`.
// Set on a descriptor when another descriptor follows it in the chain.
const VIRTQ_DESC_F_NEXT: u16 = 0x1;
// Set on a descriptor whose buffer is `DescriptorType::Writable`.
const VIRTQ_DESC_F_WRITE: u16 = 0x2;
744 
/// Kind of buffer a descriptor refers to when building a chain with `create_descriptor_chain`.
#[derive(Copy, Clone, Debug, PartialEq, Eq)]
pub enum DescriptorType {
    /// The descriptor is created without the write flag.
    Readable,
    /// The descriptor is created with the write flag set.
    Writable,
}
750 
/// In-memory layout of a single descriptor as written into guest memory by
/// `create_descriptor_chain`. All fields are little-endian.
#[derive(Copy, Clone, Debug)]
#[repr(C)]
struct virtq_desc {
    // Guest address of the buffer this descriptor refers to.
    addr: Le64,
    // Length of the buffer in bytes.
    len: Le32,
    // Combination of the `VIRTQ_DESC_F_*` flag bits.
    flags: Le16,
    // Index of the following descriptor in the descriptor array.
    next: Le16,
}

// Safe because it only has data and has no implicit padding.
unsafe impl DataInit for virtq_desc {}
762 
763 /// Test utility function to create a descriptor chain in guest memory.
create_descriptor_chain( memory: &GuestMemory, descriptor_array_addr: GuestAddress, mut buffers_start_addr: GuestAddress, descriptors: Vec<(DescriptorType, u32)>, spaces_between_regions: u32, ) -> Result<DescriptorChain>764 pub fn create_descriptor_chain(
765     memory: &GuestMemory,
766     descriptor_array_addr: GuestAddress,
767     mut buffers_start_addr: GuestAddress,
768     descriptors: Vec<(DescriptorType, u32)>,
769     spaces_between_regions: u32,
770 ) -> Result<DescriptorChain> {
771     let descriptors_len = descriptors.len();
772     for (index, (type_, size)) in descriptors.into_iter().enumerate() {
773         let mut flags = 0;
774         if let DescriptorType::Writable = type_ {
775             flags |= VIRTQ_DESC_F_WRITE;
776         }
777         if index + 1 < descriptors_len {
778             flags |= VIRTQ_DESC_F_NEXT;
779         }
780 
781         let index = index as u16;
782         let desc = virtq_desc {
783             addr: buffers_start_addr.offset().into(),
784             len: size.into(),
785             flags: flags.into(),
786             next: (index + 1).into(),
787         };
788 
789         let offset = size + spaces_between_regions;
790         buffers_start_addr = buffers_start_addr
791             .checked_add(offset as u64)
792             .context("Invalid buffers_start_addr)")
793             .map_err(Error::InvalidChain)?;
794 
795         let _ = memory.write_obj_at_addr(
796             desc,
797             descriptor_array_addr
798                 .checked_add(index as u64 * std::mem::size_of::<virtq_desc>() as u64)
799                 .context("Invalid descriptor_array_addr")
800                 .map_err(Error::InvalidChain)?,
801         );
802     }
803 
804     DescriptorChain::checked_new(memory, descriptor_array_addr, 0x100, 0, 0, None)
805         .map_err(Error::InvalidChain)
806 }
807 
808 #[cfg(test)]
809 mod tests {
810     use super::*;
811     use std::fs::File;
812     use std::io::Read;
813     use tempfile::tempfile;
814 
815     use cros_async::Executor;
816 
817     #[test]
reader_test_simple_chain()818     fn reader_test_simple_chain() {
819         use DescriptorType::*;
820 
821         let memory_start_addr = GuestAddress(0x0);
822         let memory = GuestMemory::new(&[(memory_start_addr, 0x10000)]).unwrap();
823 
824         let chain = create_descriptor_chain(
825             &memory,
826             GuestAddress(0x0),
827             GuestAddress(0x100),
828             vec![
829                 (Readable, 8),
830                 (Readable, 16),
831                 (Readable, 18),
832                 (Readable, 64),
833             ],
834             0,
835         )
836         .expect("create_descriptor_chain failed");
837         let mut reader = Reader::new(memory, chain).expect("failed to create Reader");
838         assert_eq!(reader.available_bytes(), 106);
839         assert_eq!(reader.bytes_read(), 0);
840 
841         let mut buffer = [0u8; 64];
842         reader
843             .read_exact(&mut buffer)
844             .expect("read_exact should not fail here");
845 
846         assert_eq!(reader.available_bytes(), 42);
847         assert_eq!(reader.bytes_read(), 64);
848 
849         match reader.read(&mut buffer) {
850             Err(_) => panic!("read should not fail here"),
851             Ok(length) => assert_eq!(length, 42),
852         }
853 
854         assert_eq!(reader.available_bytes(), 0);
855         assert_eq!(reader.bytes_read(), 106);
856     }
857 
858     #[test]
writer_test_simple_chain()859     fn writer_test_simple_chain() {
860         use DescriptorType::*;
861 
862         let memory_start_addr = GuestAddress(0x0);
863         let memory = GuestMemory::new(&[(memory_start_addr, 0x10000)]).unwrap();
864 
865         let chain = create_descriptor_chain(
866             &memory,
867             GuestAddress(0x0),
868             GuestAddress(0x100),
869             vec![
870                 (Writable, 8),
871                 (Writable, 16),
872                 (Writable, 18),
873                 (Writable, 64),
874             ],
875             0,
876         )
877         .expect("create_descriptor_chain failed");
878         let mut writer = Writer::new(memory, chain).expect("failed to create Writer");
879         assert_eq!(writer.available_bytes(), 106);
880         assert_eq!(writer.bytes_written(), 0);
881 
882         let buffer = [0; 64];
883         writer
884             .write_all(&buffer)
885             .expect("write_all should not fail here");
886 
887         assert_eq!(writer.available_bytes(), 42);
888         assert_eq!(writer.bytes_written(), 64);
889 
890         match writer.write(&buffer) {
891             Err(_) => panic!("write should not fail here"),
892             Ok(length) => assert_eq!(length, 42),
893         }
894 
895         assert_eq!(writer.available_bytes(), 0);
896         assert_eq!(writer.bytes_written(), 106);
897     }
898 
899     #[test]
reader_test_incompatible_chain()900     fn reader_test_incompatible_chain() {
901         use DescriptorType::*;
902 
903         let memory_start_addr = GuestAddress(0x0);
904         let memory = GuestMemory::new(&[(memory_start_addr, 0x10000)]).unwrap();
905 
906         let chain = create_descriptor_chain(
907             &memory,
908             GuestAddress(0x0),
909             GuestAddress(0x100),
910             vec![(Writable, 8)],
911             0,
912         )
913         .expect("create_descriptor_chain failed");
914         let mut reader = Reader::new(memory, chain).expect("failed to create Reader");
915         assert_eq!(reader.available_bytes(), 0);
916         assert_eq!(reader.bytes_read(), 0);
917 
918         assert!(reader.read_obj::<u8>().is_err());
919 
920         assert_eq!(reader.available_bytes(), 0);
921         assert_eq!(reader.bytes_read(), 0);
922     }
923 
924     #[test]
writer_test_incompatible_chain()925     fn writer_test_incompatible_chain() {
926         use DescriptorType::*;
927 
928         let memory_start_addr = GuestAddress(0x0);
929         let memory = GuestMemory::new(&[(memory_start_addr, 0x10000)]).unwrap();
930 
931         let chain = create_descriptor_chain(
932             &memory,
933             GuestAddress(0x0),
934             GuestAddress(0x100),
935             vec![(Readable, 8)],
936             0,
937         )
938         .expect("create_descriptor_chain failed");
939         let mut writer = Writer::new(memory, chain).expect("failed to create Writer");
940         assert_eq!(writer.available_bytes(), 0);
941         assert_eq!(writer.bytes_written(), 0);
942 
943         assert!(writer.write_obj(0u8).is_err());
944 
945         assert_eq!(writer.available_bytes(), 0);
946         assert_eq!(writer.bytes_written(), 0);
947     }
948 
949     #[test]
reader_failing_io()950     fn reader_failing_io() {
951         use DescriptorType::*;
952 
953         let memory_start_addr = GuestAddress(0x0);
954         let memory = GuestMemory::new(&[(memory_start_addr, 0x10000)]).unwrap();
955 
956         let chain = create_descriptor_chain(
957             &memory,
958             GuestAddress(0x0),
959             GuestAddress(0x100),
960             vec![(Readable, 256), (Readable, 256)],
961             0,
962         )
963         .expect("create_descriptor_chain failed");
964 
965         let mut reader = Reader::new(memory, chain).expect("failed to create Reader");
966 
967         // Open a file in read-only mode so writes to it to trigger an I/O error.
968         let mut ro_file = File::open("/dev/zero").expect("failed to open /dev/zero");
969 
970         reader
971             .read_exact_to(&mut ro_file, 512)
972             .expect_err("successfully read more bytes than SharedMemory size");
973 
974         // The write above should have failed entirely, so we end up not writing any bytes at all.
975         assert_eq!(reader.available_bytes(), 512);
976         assert_eq!(reader.bytes_read(), 0);
977     }
978 
979     #[test]
writer_failing_io()980     fn writer_failing_io() {
981         use DescriptorType::*;
982 
983         let memory_start_addr = GuestAddress(0x0);
984         let memory = GuestMemory::new(&[(memory_start_addr, 0x10000)]).unwrap();
985 
986         let chain = create_descriptor_chain(
987             &memory,
988             GuestAddress(0x0),
989             GuestAddress(0x100),
990             vec![(Writable, 256), (Writable, 256)],
991             0,
992         )
993         .expect("create_descriptor_chain failed");
994 
995         let mut writer = Writer::new(memory, chain).expect("failed to create Writer");
996 
997         let mut file = tempfile().unwrap();
998 
999         file.set_len(384).unwrap();
1000 
1001         writer
1002             .write_all_from(&mut file, 512)
1003             .expect_err("successfully wrote more bytes than in SharedMemory");
1004 
1005         assert_eq!(writer.available_bytes(), 128);
1006         assert_eq!(writer.bytes_written(), 384);
1007     }
1008 
1009     #[test]
reader_writer_shared_chain()1010     fn reader_writer_shared_chain() {
1011         use DescriptorType::*;
1012 
1013         let memory_start_addr = GuestAddress(0x0);
1014         let memory = GuestMemory::new(&[(memory_start_addr, 0x10000)]).unwrap();
1015 
1016         let chain = create_descriptor_chain(
1017             &memory,
1018             GuestAddress(0x0),
1019             GuestAddress(0x100),
1020             vec![
1021                 (Readable, 16),
1022                 (Readable, 16),
1023                 (Readable, 96),
1024                 (Writable, 64),
1025                 (Writable, 1),
1026                 (Writable, 3),
1027             ],
1028             0,
1029         )
1030         .expect("create_descriptor_chain failed");
1031         let mut reader =
1032             Reader::new(memory.clone(), chain.clone()).expect("failed to create Reader");
1033         let mut writer = Writer::new(memory, chain).expect("failed to create Writer");
1034 
1035         assert_eq!(reader.bytes_read(), 0);
1036         assert_eq!(writer.bytes_written(), 0);
1037 
1038         let mut buffer = Vec::with_capacity(200);
1039 
1040         assert_eq!(
1041             reader
1042                 .read_to_end(&mut buffer)
1043                 .expect("read should not fail here"),
1044             128
1045         );
1046 
1047         // The writable descriptors are only 68 bytes long.
1048         writer
1049             .write_all(&buffer[..68])
1050             .expect("write should not fail here");
1051 
1052         assert_eq!(reader.available_bytes(), 0);
1053         assert_eq!(reader.bytes_read(), 128);
1054         assert_eq!(writer.available_bytes(), 0);
1055         assert_eq!(writer.bytes_written(), 68);
1056     }
1057 
1058     #[test]
reader_writer_shattered_object()1059     fn reader_writer_shattered_object() {
1060         use DescriptorType::*;
1061 
1062         let memory_start_addr = GuestAddress(0x0);
1063         let memory = GuestMemory::new(&[(memory_start_addr, 0x10000)]).unwrap();
1064 
1065         let secret: Le32 = 0x12345678.into();
1066 
1067         // Create a descriptor chain with memory regions that are properly separated.
1068         let chain_writer = create_descriptor_chain(
1069             &memory,
1070             GuestAddress(0x0),
1071             GuestAddress(0x100),
1072             vec![(Writable, 1), (Writable, 1), (Writable, 1), (Writable, 1)],
1073             123,
1074         )
1075         .expect("create_descriptor_chain failed");
1076         let mut writer =
1077             Writer::new(memory.clone(), chain_writer).expect("failed to create Writer");
1078         writer
1079             .write_obj(secret)
1080             .expect("write_obj should not fail here");
1081 
1082         // Now create new descriptor chain pointing to the same memory and try to read it.
1083         let chain_reader = create_descriptor_chain(
1084             &memory,
1085             GuestAddress(0x0),
1086             GuestAddress(0x100),
1087             vec![(Readable, 1), (Readable, 1), (Readable, 1), (Readable, 1)],
1088             123,
1089         )
1090         .expect("create_descriptor_chain failed");
1091         let mut reader = Reader::new(memory, chain_reader).expect("failed to create Reader");
1092         match reader.read_obj::<Le32>() {
1093             Err(_) => panic!("read_obj should not fail here"),
1094             Ok(read_secret) => assert_eq!(read_secret, secret),
1095         }
1096     }
1097 
1098     #[test]
reader_unexpected_eof()1099     fn reader_unexpected_eof() {
1100         use DescriptorType::*;
1101 
1102         let memory_start_addr = GuestAddress(0x0);
1103         let memory = GuestMemory::new(&[(memory_start_addr, 0x10000)]).unwrap();
1104 
1105         let chain = create_descriptor_chain(
1106             &memory,
1107             GuestAddress(0x0),
1108             GuestAddress(0x100),
1109             vec![(Readable, 256), (Readable, 256)],
1110             0,
1111         )
1112         .expect("create_descriptor_chain failed");
1113 
1114         let mut reader = Reader::new(memory, chain).expect("failed to create Reader");
1115 
1116         let mut buf = vec![0; 1024];
1117 
1118         assert_eq!(
1119             reader
1120                 .read_exact(&mut buf[..])
1121                 .expect_err("read more bytes than available")
1122                 .kind(),
1123             io::ErrorKind::UnexpectedEof
1124         );
1125     }
1126 
1127     #[test]
split_border()1128     fn split_border() {
1129         use DescriptorType::*;
1130 
1131         let memory_start_addr = GuestAddress(0x0);
1132         let memory = GuestMemory::new(&[(memory_start_addr, 0x10000)]).unwrap();
1133 
1134         let chain = create_descriptor_chain(
1135             &memory,
1136             GuestAddress(0x0),
1137             GuestAddress(0x100),
1138             vec![
1139                 (Readable, 16),
1140                 (Readable, 16),
1141                 (Readable, 96),
1142                 (Writable, 64),
1143                 (Writable, 1),
1144                 (Writable, 3),
1145             ],
1146             0,
1147         )
1148         .expect("create_descriptor_chain failed");
1149         let mut reader = Reader::new(memory, chain).expect("failed to create Reader");
1150 
1151         let other = reader.split_at(32);
1152         assert_eq!(reader.available_bytes(), 32);
1153         assert_eq!(other.available_bytes(), 96);
1154     }
1155 
1156     #[test]
split_middle()1157     fn split_middle() {
1158         use DescriptorType::*;
1159 
1160         let memory_start_addr = GuestAddress(0x0);
1161         let memory = GuestMemory::new(&[(memory_start_addr, 0x10000)]).unwrap();
1162 
1163         let chain = create_descriptor_chain(
1164             &memory,
1165             GuestAddress(0x0),
1166             GuestAddress(0x100),
1167             vec![
1168                 (Readable, 16),
1169                 (Readable, 16),
1170                 (Readable, 96),
1171                 (Writable, 64),
1172                 (Writable, 1),
1173                 (Writable, 3),
1174             ],
1175             0,
1176         )
1177         .expect("create_descriptor_chain failed");
1178         let mut reader = Reader::new(memory, chain).expect("failed to create Reader");
1179 
1180         let other = reader.split_at(24);
1181         assert_eq!(reader.available_bytes(), 24);
1182         assert_eq!(other.available_bytes(), 104);
1183     }
1184 
1185     #[test]
split_end()1186     fn split_end() {
1187         use DescriptorType::*;
1188 
1189         let memory_start_addr = GuestAddress(0x0);
1190         let memory = GuestMemory::new(&[(memory_start_addr, 0x10000)]).unwrap();
1191 
1192         let chain = create_descriptor_chain(
1193             &memory,
1194             GuestAddress(0x0),
1195             GuestAddress(0x100),
1196             vec![
1197                 (Readable, 16),
1198                 (Readable, 16),
1199                 (Readable, 96),
1200                 (Writable, 64),
1201                 (Writable, 1),
1202                 (Writable, 3),
1203             ],
1204             0,
1205         )
1206         .expect("create_descriptor_chain failed");
1207         let mut reader = Reader::new(memory, chain).expect("failed to create Reader");
1208 
1209         let other = reader.split_at(128);
1210         assert_eq!(reader.available_bytes(), 128);
1211         assert_eq!(other.available_bytes(), 0);
1212     }
1213 
1214     #[test]
split_beginning()1215     fn split_beginning() {
1216         use DescriptorType::*;
1217 
1218         let memory_start_addr = GuestAddress(0x0);
1219         let memory = GuestMemory::new(&[(memory_start_addr, 0x10000)]).unwrap();
1220 
1221         let chain = create_descriptor_chain(
1222             &memory,
1223             GuestAddress(0x0),
1224             GuestAddress(0x100),
1225             vec![
1226                 (Readable, 16),
1227                 (Readable, 16),
1228                 (Readable, 96),
1229                 (Writable, 64),
1230                 (Writable, 1),
1231                 (Writable, 3),
1232             ],
1233             0,
1234         )
1235         .expect("create_descriptor_chain failed");
1236         let mut reader = Reader::new(memory, chain).expect("failed to create Reader");
1237 
1238         let other = reader.split_at(0);
1239         assert_eq!(reader.available_bytes(), 0);
1240         assert_eq!(other.available_bytes(), 128);
1241     }
1242 
1243     #[test]
split_outofbounds()1244     fn split_outofbounds() {
1245         use DescriptorType::*;
1246 
1247         let memory_start_addr = GuestAddress(0x0);
1248         let memory = GuestMemory::new(&[(memory_start_addr, 0x10000)]).unwrap();
1249 
1250         let chain = create_descriptor_chain(
1251             &memory,
1252             GuestAddress(0x0),
1253             GuestAddress(0x100),
1254             vec![
1255                 (Readable, 16),
1256                 (Readable, 16),
1257                 (Readable, 96),
1258                 (Writable, 64),
1259                 (Writable, 1),
1260                 (Writable, 3),
1261             ],
1262             0,
1263         )
1264         .expect("create_descriptor_chain failed");
1265         let mut reader = Reader::new(memory, chain).expect("failed to create Reader");
1266 
1267         let other = reader.split_at(256);
1268         assert_eq!(
1269             other.available_bytes(),
1270             0,
1271             "Reader returned from out-of-bounds split still has available bytes"
1272         );
1273     }
1274 
1275     #[test]
read_full()1276     fn read_full() {
1277         use DescriptorType::*;
1278 
1279         let memory_start_addr = GuestAddress(0x0);
1280         let memory = GuestMemory::new(&[(memory_start_addr, 0x10000)]).unwrap();
1281 
1282         let chain = create_descriptor_chain(
1283             &memory,
1284             GuestAddress(0x0),
1285             GuestAddress(0x100),
1286             vec![(Readable, 16), (Readable, 16), (Readable, 16)],
1287             0,
1288         )
1289         .expect("create_descriptor_chain failed");
1290         let mut reader = Reader::new(memory, chain).expect("failed to create Reader");
1291 
1292         let mut buf = vec![0u8; 64];
1293         assert_eq!(
1294             reader.read(&mut buf[..]).expect("failed to read to buffer"),
1295             48
1296         );
1297     }
1298 
1299     #[test]
write_full()1300     fn write_full() {
1301         use DescriptorType::*;
1302 
1303         let memory_start_addr = GuestAddress(0x0);
1304         let memory = GuestMemory::new(&[(memory_start_addr, 0x10000)]).unwrap();
1305 
1306         let chain = create_descriptor_chain(
1307             &memory,
1308             GuestAddress(0x0),
1309             GuestAddress(0x100),
1310             vec![(Writable, 16), (Writable, 16), (Writable, 16)],
1311             0,
1312         )
1313         .expect("create_descriptor_chain failed");
1314         let mut writer = Writer::new(memory, chain).expect("failed to create Writer");
1315 
1316         let buf = vec![0xdeu8; 64];
1317         assert_eq!(
1318             writer.write(&buf[..]).expect("failed to write from buffer"),
1319             48
1320         );
1321     }
1322 
1323     #[test]
consume_collect()1324     fn consume_collect() {
1325         use DescriptorType::*;
1326 
1327         let memory_start_addr = GuestAddress(0x0);
1328         let memory = GuestMemory::new(&[(memory_start_addr, 0x10000)]).unwrap();
1329         let vs: Vec<Le64> = vec![
1330             0x0101010101010101.into(),
1331             0x0202020202020202.into(),
1332             0x0303030303030303.into(),
1333         ];
1334 
1335         let write_chain = create_descriptor_chain(
1336             &memory,
1337             GuestAddress(0x0),
1338             GuestAddress(0x100),
1339             vec![(Writable, 24)],
1340             0,
1341         )
1342         .expect("create_descriptor_chain failed");
1343         let mut writer = Writer::new(memory.clone(), write_chain).expect("failed to create Writer");
1344         writer
1345             .consume(vs.clone())
1346             .expect("failed to consume() a vector");
1347 
1348         let read_chain = create_descriptor_chain(
1349             &memory,
1350             GuestAddress(0x0),
1351             GuestAddress(0x100),
1352             vec![(Readable, 24)],
1353             0,
1354         )
1355         .expect("create_descriptor_chain failed");
1356         let mut reader = Reader::new(memory, read_chain).expect("failed to create Reader");
1357         let vs_read = reader
1358             .collect::<io::Result<Vec<Le64>>, _>()
1359             .expect("failed to collect() values");
1360         assert_eq!(vs, vs_read);
1361     }
1362 
1363     #[test]
get_remaining_region_with_count()1364     fn get_remaining_region_with_count() {
1365         use DescriptorType::*;
1366 
1367         let memory_start_addr = GuestAddress(0x0);
1368         let memory = GuestMemory::new(&[(memory_start_addr, 0x10000)]).unwrap();
1369 
1370         let chain = create_descriptor_chain(
1371             &memory,
1372             GuestAddress(0x0),
1373             GuestAddress(0x100),
1374             vec![
1375                 (Readable, 16),
1376                 (Readable, 16),
1377                 (Readable, 96),
1378                 (Writable, 64),
1379                 (Writable, 1),
1380                 (Writable, 3),
1381             ],
1382             0,
1383         )
1384         .expect("create_descriptor_chain failed");
1385 
1386         let Reader {
1387             mem: _,
1388             mut regions,
1389         } = Reader::new(memory, chain).expect("failed to create Reader");
1390 
1391         let drain = regions
1392             .get_remaining_regions_with_count(::std::usize::MAX)
1393             .iter()
1394             .fold(0usize, |total, region| total + region.len);
1395         assert_eq!(drain, 128);
1396 
1397         let exact = regions
1398             .get_remaining_regions_with_count(32)
1399             .iter()
1400             .fold(0usize, |total, region| total + region.len);
1401         assert!(exact > 0);
1402         assert!(exact <= 32);
1403 
1404         let split = regions
1405             .get_remaining_regions_with_count(24)
1406             .iter()
1407             .fold(0usize, |total, region| total + region.len);
1408         assert!(split > 0);
1409         assert!(split <= 24);
1410 
1411         regions.consume(64);
1412 
1413         let first = regions
1414             .get_remaining_regions_with_count(8)
1415             .iter()
1416             .fold(0usize, |total, region| total + region.len);
1417         assert!(first > 0);
1418         assert!(first <= 8);
1419     }
1420 
1421     #[test]
get_remaining_with_count()1422     fn get_remaining_with_count() {
1423         use DescriptorType::*;
1424 
1425         let memory_start_addr = GuestAddress(0x0);
1426         let memory = GuestMemory::new(&[(memory_start_addr, 0x10000)]).unwrap();
1427 
1428         let chain = create_descriptor_chain(
1429             &memory,
1430             GuestAddress(0x0),
1431             GuestAddress(0x100),
1432             vec![
1433                 (Readable, 16),
1434                 (Readable, 16),
1435                 (Readable, 96),
1436                 (Writable, 64),
1437                 (Writable, 1),
1438                 (Writable, 3),
1439             ],
1440             0,
1441         )
1442         .expect("create_descriptor_chain failed");
1443         let Reader {
1444             mem: _,
1445             mut regions,
1446         } = Reader::new(memory.clone(), chain).expect("failed to create Reader");
1447 
1448         let drain = regions
1449             .get_remaining_with_count(&memory, ::std::usize::MAX)
1450             .iter()
1451             .fold(0usize, |total, iov| total + iov.size());
1452         assert_eq!(drain, 128);
1453 
1454         let exact = regions
1455             .get_remaining_with_count(&memory, 32)
1456             .iter()
1457             .fold(0usize, |total, iov| total + iov.size());
1458         assert!(exact > 0);
1459         assert!(exact <= 32);
1460 
1461         let split = regions
1462             .get_remaining_with_count(&memory, 24)
1463             .iter()
1464             .fold(0usize, |total, iov| total + iov.size());
1465         assert!(split > 0);
1466         assert!(split <= 24);
1467 
1468         regions.consume(64);
1469 
1470         let first = regions
1471             .get_remaining_with_count(&memory, 8)
1472             .iter()
1473             .fold(0usize, |total, iov| total + iov.size());
1474         assert!(first > 0);
1475         assert!(first <= 8);
1476     }
1477 
1478     #[test]
region_reader_failing_io()1479     fn region_reader_failing_io() {
1480         let ex = Executor::new().unwrap();
1481         ex.run_until(region_reader_failing_io_async(&ex)).unwrap();
1482     }
region_reader_failing_io_async(ex: &Executor)1483     async fn region_reader_failing_io_async(ex: &Executor) {
1484         use DescriptorType::*;
1485 
1486         let memory_start_addr = GuestAddress(0x0);
1487         let memory = GuestMemory::new(&[(memory_start_addr, 0x10000)]).unwrap();
1488 
1489         let chain = create_descriptor_chain(
1490             &memory,
1491             GuestAddress(0x0),
1492             GuestAddress(0x100),
1493             vec![(Readable, 256), (Readable, 256)],
1494             0,
1495         )
1496         .expect("create_descriptor_chain failed");
1497 
1498         let mut reader = Reader::new(memory.clone(), chain).expect("failed to create Reader");
1499 
1500         // Open a file in read-only mode so writes to it to trigger an I/O error.
1501         let ro_file = File::open("/dev/zero").expect("failed to open /dev/zero");
1502         let async_ro_file = disk::SingleFileDisk::new(ro_file, ex).expect("Failed to crate SFD");
1503 
1504         reader
1505             .read_exact_to_at_fut(&async_ro_file, 512, 0)
1506             .await
1507             .expect_err("successfully read more bytes than SharedMemory size");
1508 
1509         // The write above should have failed entirely, so we end up not writing any bytes at all.
1510         assert_eq!(reader.available_bytes(), 512);
1511         assert_eq!(reader.bytes_read(), 0);
1512     }
1513 
1514     #[test]
region_writer_failing_io()1515     fn region_writer_failing_io() {
1516         let ex = Executor::new().unwrap();
1517         ex.run_until(region_writer_failing_io_async(&ex)).unwrap()
1518     }
region_writer_failing_io_async(ex: &Executor)1519     async fn region_writer_failing_io_async(ex: &Executor) {
1520         use DescriptorType::*;
1521 
1522         let memory_start_addr = GuestAddress(0x0);
1523         let memory = GuestMemory::new(&[(memory_start_addr, 0x10000)]).unwrap();
1524 
1525         let chain = create_descriptor_chain(
1526             &memory,
1527             GuestAddress(0x0),
1528             GuestAddress(0x100),
1529             vec![(Writable, 256), (Writable, 256)],
1530             0,
1531         )
1532         .expect("create_descriptor_chain failed");
1533 
1534         let mut writer = Writer::new(memory.clone(), chain).expect("failed to create Writer");
1535 
1536         let file = tempfile().expect("failed to create temp file");
1537 
1538         file.set_len(384).unwrap();
1539         let async_file = disk::SingleFileDisk::new(file, ex).expect("Failed to crate SFD");
1540 
1541         writer
1542             .write_all_from_at_fut(&async_file, 512, 0)
1543             .await
1544             .expect_err("successfully wrote more bytes than in SharedMemory");
1545 
1546         assert_eq!(writer.available_bytes(), 128);
1547         assert_eq!(writer.bytes_written(), 384);
1548     }
1549 }
1550