• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 // Copyright 2019 The Chromium OS Authors. All rights reserved.
2 // Use of this source code is governed by a BSD-style license that can be
3 // found in the LICENSE file.
4 
5 use std::borrow::Cow;
6 use std::cmp;
7 use std::convert::TryInto;
8 use std::fmt::{self, Display};
9 use std::io::{self, Read, Write};
10 use std::iter::FromIterator;
11 use std::marker::PhantomData;
12 use std::mem::{size_of, MaybeUninit};
13 use std::ptr::copy_nonoverlapping;
14 use std::result;
15 use std::sync::Arc;
16 
17 use base::{FileReadWriteAtVolatile, FileReadWriteVolatile};
18 use cros_async::MemRegion;
19 use data_model::{DataInit, Le16, Le32, Le64, VolatileMemoryError, VolatileSlice};
20 use disk::AsyncDisk;
21 use smallvec::SmallVec;
22 use vm_memory::{GuestAddress, GuestMemory};
23 
24 use super::DescriptorChain;
25 
26 #[derive(Debug)]
27 pub enum Error {
28     DescriptorChainOverflow,
29     GuestMemoryError(vm_memory::GuestMemoryError),
30     InvalidChain,
31     IoError(io::Error),
32     SplitOutOfBounds(usize),
33     VolatileMemoryError(VolatileMemoryError),
34 }
35 
36 impl Display for Error {
fmt(&self, f: &mut fmt::Formatter) -> fmt::Result37     fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
38         use self::Error::*;
39 
40         match self {
41             DescriptorChainOverflow => write!(
42                 f,
43                 "the combined length of all the buffers in a `DescriptorChain` would overflow"
44             ),
45             GuestMemoryError(e) => write!(f, "descriptor guest memory error: {}", e),
46             InvalidChain => write!(f, "invalid descriptor chain"),
47             IoError(e) => write!(f, "descriptor I/O error: {}", e),
48             SplitOutOfBounds(off) => write!(f, "`DescriptorChain` split is out of bounds: {}", off),
49             VolatileMemoryError(e) => write!(f, "volatile memory error: {}", e),
50         }
51     }
52 }
53 
54 pub type Result<T> = result::Result<T, Error>;
55 
56 impl std::error::Error for Error {}
57 
58 #[derive(Clone)]
59 struct DescriptorChainRegions {
60     regions: SmallVec<[MemRegion; 16]>,
61     current: usize,
62     bytes_consumed: usize,
63 }
64 
65 impl DescriptorChainRegions {
available_bytes(&self) -> usize66     fn available_bytes(&self) -> usize {
67         // This is guaranteed not to overflow because the total length of the chain
68         // is checked during all creations of `DescriptorChainRegions` (see
69         // `Reader::new()` and `Writer::new()`).
70         self.get_remaining_regions()
71             .iter()
72             .fold(0usize, |count, region| count + region.len)
73     }
74 
bytes_consumed(&self) -> usize75     fn bytes_consumed(&self) -> usize {
76         self.bytes_consumed
77     }
78 
79     /// Returns all the remaining buffers in the `DescriptorChain`. Calling this function does not
80     /// consume any bytes from the `DescriptorChain`. Instead callers should use the `consume`
81     /// method to advance the `DescriptorChain`. Multiple calls to `get` with no intervening calls
82     /// to `consume` will return the same data.
get_remaining_regions(&self) -> &[MemRegion]83     fn get_remaining_regions(&self) -> &[MemRegion] {
84         &self.regions[self.current..]
85     }
86 
87     /// Returns all the remaining buffers in the `DescriptorChain` as `VolatileSlice`s of the given
88     /// `GuestMemory`. Calling this function does not consume any bytes from the `DescriptorChain`.
89     /// Instead callers should use the `consume` method to advance the `DescriptorChain`. Multiple
90     /// calls to `get` with no intervening calls to `consume` will return the same data.
get_remaining<'mem>(&self, mem: &'mem GuestMemory) -> SmallVec<[VolatileSlice<'mem>; 16]>91     fn get_remaining<'mem>(&self, mem: &'mem GuestMemory) -> SmallVec<[VolatileSlice<'mem>; 16]> {
92         self.get_remaining_regions()
93             .iter()
94             .filter_map(|region| {
95                 mem.get_slice_at_addr(GuestAddress(region.offset), region.len)
96                     .ok()
97             })
98             .collect()
99     }
100 
101     /// Like `get_remaining` but guarantees that the combined length of all the returned iovecs is
102     /// not greater than `count`. The combined length of the returned iovecs may be less than
103     /// `count` but will always be greater than 0 as long as there is still space left in the
104     /// `DescriptorChain`.
get_remaining_regions_with_count(&self, count: usize) -> Cow<[MemRegion]>105     fn get_remaining_regions_with_count(&self, count: usize) -> Cow<[MemRegion]> {
106         let regions = self.get_remaining_regions();
107         let mut region_count = 0;
108         let mut rem = count;
109         for region in regions {
110             if rem < region.len {
111                 break;
112             }
113 
114             region_count += 1;
115             rem -= region.len;
116         }
117 
118         // Special case where the number of bytes to be copied is smaller than the `size()` of the
119         // first regions.
120         if region_count == 0 && !regions.is_empty() && count > 0 {
121             debug_assert!(count < regions[0].len);
122             // Safe because we know that count is smaller than the length of the first slice.
123             Cow::Owned(vec![MemRegion {
124                 offset: regions[0].offset,
125                 len: count,
126             }])
127         } else {
128             Cow::Borrowed(&regions[..region_count])
129         }
130     }
131 
132     /// Like 'get_remaining_with_count' except convert the offsets to volatile slices in the
133     /// 'GuestMemory' given by 'mem'.
get_remaining_with_count<'mem>( &self, mem: &'mem GuestMemory, count: usize, ) -> SmallVec<[VolatileSlice<'mem>; 16]>134     fn get_remaining_with_count<'mem>(
135         &self,
136         mem: &'mem GuestMemory,
137         count: usize,
138     ) -> SmallVec<[VolatileSlice<'mem>; 16]> {
139         self.get_remaining_regions_with_count(count)
140             .iter()
141             .filter_map(|region| {
142                 mem.get_slice_at_addr(GuestAddress(region.offset), region.len)
143                     .ok()
144             })
145             .collect()
146     }
147 
148     /// Consumes `count` bytes from the `DescriptorChain`. If `count` is larger than
149     /// `self.available_bytes()` then all remaining bytes in the `DescriptorChain` will be consumed.
consume(&mut self, mut count: usize)150     fn consume(&mut self, mut count: usize) {
151         // The implementation is adapted from `IoSlice::advance` in libstd. We can't use
152         // `get_remaining` here because then the compiler complains that `self.current` is already
153         // borrowed and doesn't allow us to modify it.  We also need to borrow the iovecs mutably.
154         let current = self.current;
155         for region in &mut self.regions[current..] {
156             if count == 0 {
157                 break;
158             }
159 
160             let consumed = if count < region.len {
161                 // Safe because we know that the iovec pointed to valid memory and we are adding a
162                 // value that is smaller than the length of the memory.
163                 *region = MemRegion {
164                     offset: region.offset + count as u64,
165                     len: region.len - count,
166                 };
167                 count
168             } else {
169                 self.current += 1;
170                 region.len
171             };
172 
173             // This shouldn't overflow because `consumed <= buf.size()` and we already verified
174             // that adding all `buf.size()` values will not overflow when the Reader/Writer was
175             // constructed.
176             self.bytes_consumed += consumed;
177             count -= consumed;
178         }
179     }
180 
split_at(&mut self, offset: usize) -> DescriptorChainRegions181     fn split_at(&mut self, offset: usize) -> DescriptorChainRegions {
182         let mut other = self.clone();
183         other.consume(offset);
184         other.bytes_consumed = 0;
185 
186         let mut rem = offset;
187         let mut end = self.current;
188         for region in &mut self.regions[self.current..] {
189             if rem <= region.len {
190                 region.len = rem;
191                 break;
192             }
193 
194             end += 1;
195             rem -= region.len;
196         }
197 
198         self.regions.truncate(end + 1);
199 
200         other
201     }
202 }
203 
204 /// Provides high-level interface over the sequence of memory regions
205 /// defined by readable descriptors in the descriptor chain.
206 ///
207 /// Note that virtio spec requires driver to place any device-writable
208 /// descriptors after any device-readable descriptors (2.6.4.2 in Virtio Spec v1.1).
209 /// Reader will skip iterating over descriptor chain when first writable
210 /// descriptor is encountered.
211 #[derive(Clone)]
212 pub struct Reader {
213     mem: GuestMemory,
214     regions: DescriptorChainRegions,
215 }
216 
217 /// An iterator over `DataInit` objects on readable descriptors in the descriptor chain.
218 pub struct ReaderIterator<'a, T: DataInit> {
219     reader: &'a mut Reader,
220     phantom: PhantomData<T>,
221 }
222 
223 impl<'a, T: DataInit> Iterator for ReaderIterator<'a, T> {
224     type Item = io::Result<T>;
225 
next(&mut self) -> Option<io::Result<T>>226     fn next(&mut self) -> Option<io::Result<T>> {
227         if self.reader.available_bytes() == 0 {
228             None
229         } else {
230             Some(self.reader.read_obj())
231         }
232     }
233 }
234 
235 impl Reader {
236     /// Construct a new Reader wrapper over `desc_chain`.
new(mem: GuestMemory, desc_chain: DescriptorChain) -> Result<Reader>237     pub fn new(mem: GuestMemory, desc_chain: DescriptorChain) -> Result<Reader> {
238         // TODO(jstaron): Update this code to take the indirect descriptors into account.
239         let mut total_len: usize = 0;
240         let regions = desc_chain
241             .into_iter()
242             .readable()
243             .map(|desc| {
244                 // Verify that summing the descriptor sizes does not overflow.
245                 // This can happen if a driver tricks a device into reading more data than
246                 // fits in a `usize`.
247                 total_len = total_len
248                     .checked_add(desc.len as usize)
249                     .ok_or(Error::DescriptorChainOverflow)?;
250 
251                 // Check that all the regions are totally contained in GuestMemory.
252                 mem.get_slice_at_addr(
253                     desc.addr,
254                     desc.len.try_into().expect("u32 doesn't fit in usize"),
255                 )
256                 .map_err(Error::GuestMemoryError)?;
257 
258                 Ok(MemRegion {
259                     offset: desc.addr.0,
260                     len: desc.len.try_into().expect("u32 doesn't fit in usize"),
261                 })
262             })
263             .collect::<Result<SmallVec<[MemRegion; 16]>>>()?;
264         Ok(Reader {
265             mem,
266             regions: DescriptorChainRegions {
267                 regions,
268                 current: 0,
269                 bytes_consumed: 0,
270             },
271         })
272     }
273 
274     /// Reads an object from the descriptor chain buffer.
read_obj<T: DataInit>(&mut self) -> io::Result<T>275     pub fn read_obj<T: DataInit>(&mut self) -> io::Result<T> {
276         let mut obj = MaybeUninit::<T>::uninit();
277 
278         // Safe because `MaybeUninit` guarantees that the pointer is valid for
279         // `size_of::<T>()` bytes.
280         let buf = unsafe {
281             ::std::slice::from_raw_parts_mut(obj.as_mut_ptr() as *mut u8, size_of::<T>())
282         };
283 
284         self.read_exact(buf)?;
285 
286         // Safe because any type that implements `DataInit` can be considered initialized
287         // even if it is filled with random data.
288         Ok(unsafe { obj.assume_init() })
289     }
290 
291     /// Reads objects by consuming all the remaining data in the descriptor chain buffer and returns
292     /// them as a collection. Returns an error if the size of the remaining data is indivisible by
293     /// the size of an object of type `T`.
collect<C: FromIterator<io::Result<T>>, T: DataInit>(&mut self) -> C294     pub fn collect<C: FromIterator<io::Result<T>>, T: DataInit>(&mut self) -> C {
295         self.iter().collect()
296     }
297 
298     /// Creates an iterator for sequentially reading `DataInit` objects from the `Reader`.
299     /// Unlike `collect`, this doesn't consume all the remaining data in the `Reader` and
300     /// doesn't require the objects to be stored in a separate collection.
iter<T: DataInit>(&mut self) -> ReaderIterator<T>301     pub fn iter<T: DataInit>(&mut self) -> ReaderIterator<T> {
302         ReaderIterator {
303             reader: self,
304             phantom: PhantomData,
305         }
306     }
307 
308     /// Reads data from the descriptor chain buffer into a file descriptor.
309     /// Returns the number of bytes read from the descriptor chain buffer.
310     /// The number of bytes read can be less than `count` if there isn't
311     /// enough data in the descriptor chain buffer.
read_to<F: FileReadWriteVolatile>( &mut self, mut dst: F, count: usize, ) -> io::Result<usize>312     pub fn read_to<F: FileReadWriteVolatile>(
313         &mut self,
314         mut dst: F,
315         count: usize,
316     ) -> io::Result<usize> {
317         let iovs = self.regions.get_remaining_with_count(&self.mem, count);
318         let written = dst.write_vectored_volatile(&iovs[..])?;
319         self.regions.consume(written);
320         Ok(written)
321     }
322 
323     /// Reads data from the descriptor chain buffer into a File at offset `off`.
324     /// Returns the number of bytes read from the descriptor chain buffer.
325     /// The number of bytes read can be less than `count` if there isn't
326     /// enough data in the descriptor chain buffer.
read_to_at<F: FileReadWriteAtVolatile>( &mut self, mut dst: F, count: usize, off: u64, ) -> io::Result<usize>327     pub fn read_to_at<F: FileReadWriteAtVolatile>(
328         &mut self,
329         mut dst: F,
330         count: usize,
331         off: u64,
332     ) -> io::Result<usize> {
333         let iovs = self.regions.get_remaining_with_count(&self.mem, count);
334         let written = dst.write_vectored_at_volatile(&iovs[..], off)?;
335         self.regions.consume(written);
336         Ok(written)
337     }
338 
339     /// Reads data from the descriptor chain similar to 'read_to' except reading 'count' or
340     /// returning an error if 'count' bytes can't be read.
read_exact_to<F: FileReadWriteVolatile>( &mut self, mut dst: F, mut count: usize, ) -> io::Result<()>341     pub fn read_exact_to<F: FileReadWriteVolatile>(
342         &mut self,
343         mut dst: F,
344         mut count: usize,
345     ) -> io::Result<()> {
346         while count > 0 {
347             match self.read_to(&mut dst, count) {
348                 Ok(0) => {
349                     return Err(io::Error::new(
350                         io::ErrorKind::UnexpectedEof,
351                         "failed to fill whole buffer",
352                     ))
353                 }
354                 Ok(n) => count -= n,
355                 Err(ref e) if e.kind() == io::ErrorKind::Interrupted => {}
356                 Err(e) => return Err(e),
357             }
358         }
359 
360         Ok(())
361     }
362 
363     /// Reads data from the descriptor chain similar to 'read_to_at' except reading 'count' or
364     /// returning an error if 'count' bytes can't be read.
read_exact_to_at<F: FileReadWriteAtVolatile>( &mut self, mut dst: F, mut count: usize, mut off: u64, ) -> io::Result<()>365     pub fn read_exact_to_at<F: FileReadWriteAtVolatile>(
366         &mut self,
367         mut dst: F,
368         mut count: usize,
369         mut off: u64,
370     ) -> io::Result<()> {
371         while count > 0 {
372             match self.read_to_at(&mut dst, count, off) {
373                 Ok(0) => {
374                     return Err(io::Error::new(
375                         io::ErrorKind::UnexpectedEof,
376                         "failed to fill whole buffer",
377                     ))
378                 }
379                 Ok(n) => {
380                     count -= n;
381                     off += n as u64;
382                 }
383                 Err(ref e) if e.kind() == io::ErrorKind::Interrupted => {}
384                 Err(e) => return Err(e),
385             }
386         }
387 
388         Ok(())
389     }
390 
391     /// Reads data from the descriptor chain buffer into an `AsyncDisk` at offset `off`.
392     /// Returns the number of bytes read from the descriptor chain buffer.
393     /// The number of bytes read can be less than `count` if there isn't
394     /// enough data in the descriptor chain buffer.
read_to_at_fut<F: AsyncDisk + ?Sized>( &mut self, dst: &F, count: usize, off: u64, ) -> disk::Result<usize>395     pub async fn read_to_at_fut<F: AsyncDisk + ?Sized>(
396         &mut self,
397         dst: &F,
398         count: usize,
399         off: u64,
400     ) -> disk::Result<usize> {
401         let mem_regions = self.regions.get_remaining_regions_with_count(count);
402         let written = dst
403             .write_from_mem(off, Arc::new(self.mem.clone()), &mem_regions)
404             .await?;
405         self.regions.consume(written);
406         Ok(written)
407     }
408 
409     /// Reads exactly `count` bytes from the chain to the disk asynchronously or returns an error if
410     /// not enough data can be read.
read_exact_to_at_fut<F: AsyncDisk + ?Sized>( &mut self, dst: &F, mut count: usize, mut off: u64, ) -> disk::Result<()>411     pub async fn read_exact_to_at_fut<F: AsyncDisk + ?Sized>(
412         &mut self,
413         dst: &F,
414         mut count: usize,
415         mut off: u64,
416     ) -> disk::Result<()> {
417         while count > 0 {
418             let nread = self.read_to_at_fut(dst, count, off).await?;
419             if nread == 0 {
420                 return Err(disk::Error::ReadingData(io::Error::new(
421                     io::ErrorKind::UnexpectedEof,
422                     "failed to write whole buffer",
423                 )));
424             }
425             count -= nread;
426             off += nread as u64;
427         }
428 
429         Ok(())
430     }
431 
432     /// Returns number of bytes available for reading.  May return an error if the combined
433     /// lengths of all the buffers in the DescriptorChain would cause an integer overflow.
available_bytes(&self) -> usize434     pub fn available_bytes(&self) -> usize {
435         self.regions.available_bytes()
436     }
437 
438     /// Returns number of bytes already read from the descriptor chain buffer.
bytes_read(&self) -> usize439     pub fn bytes_read(&self) -> usize {
440         self.regions.bytes_consumed()
441     }
442 
443     /// Returns a `&[VolatileSlice]` that represents all the remaining data in this `Reader`.
444     /// Calling this method does not actually consume any data from the `Reader` and callers should
445     /// call `consume` to advance the `Reader`.
get_remaining(&self) -> SmallVec<[VolatileSlice; 16]>446     pub fn get_remaining(&self) -> SmallVec<[VolatileSlice; 16]> {
447         self.regions.get_remaining(&self.mem)
448     }
449 
450     /// Consumes `amt` bytes from the underlying descriptor chain. If `amt` is larger than the
451     /// remaining data left in this `Reader`, then all remaining data will be consumed.
consume(&mut self, amt: usize)452     pub fn consume(&mut self, amt: usize) {
453         self.regions.consume(amt)
454     }
455 
456     /// Splits this `Reader` into two at the given offset in the `DescriptorChain` buffer. After the
457     /// split, `self` will be able to read up to `offset` bytes while the returned `Reader` can read
458     /// up to `available_bytes() - offset` bytes. If `offset > self.available_bytes()`, then the
459     /// returned `Reader` will not be able to read any bytes.
split_at(&mut self, offset: usize) -> Reader460     pub fn split_at(&mut self, offset: usize) -> Reader {
461         Reader {
462             mem: self.mem.clone(),
463             regions: self.regions.split_at(offset),
464         }
465     }
466 }
467 
468 impl io::Read for Reader {
read(&mut self, buf: &mut [u8]) -> io::Result<usize>469     fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
470         let mut rem = buf;
471         let mut total = 0;
472         for b in self.regions.get_remaining(&self.mem) {
473             if rem.is_empty() {
474                 break;
475             }
476 
477             let count = cmp::min(rem.len(), b.size());
478 
479             // Safe because we have already verified that `b` points to valid memory.
480             unsafe {
481                 copy_nonoverlapping(b.as_ptr(), rem.as_mut_ptr(), count);
482             }
483             rem = &mut rem[count..];
484             total += count;
485         }
486 
487         self.regions.consume(total);
488         Ok(total)
489     }
490 }
491 
492 /// Provides high-level interface over the sequence of memory regions
493 /// defined by writable descriptors in the descriptor chain.
494 ///
495 /// Note that virtio spec requires driver to place any device-writable
496 /// descriptors after any device-readable descriptors (2.6.4.2 in Virtio Spec v1.1).
497 /// Writer will start iterating the descriptors from the first writable one and will
498 /// assume that all following descriptors are writable.
499 #[derive(Clone)]
500 pub struct Writer {
501     mem: GuestMemory,
502     regions: DescriptorChainRegions,
503 }
504 
505 impl Writer {
506     /// Construct a new Writer wrapper over `desc_chain`.
new(mem: GuestMemory, desc_chain: DescriptorChain) -> Result<Writer>507     pub fn new(mem: GuestMemory, desc_chain: DescriptorChain) -> Result<Writer> {
508         let mut total_len: usize = 0;
509         let regions = desc_chain
510             .into_iter()
511             .writable()
512             .map(|desc| {
513                 // Verify that summing the descriptor sizes does not overflow.
514                 // This can happen if a driver tricks a device into writing more data than
515                 // fits in a `usize`.
516                 total_len = total_len
517                     .checked_add(desc.len as usize)
518                     .ok_or(Error::DescriptorChainOverflow)?;
519 
520                 mem.get_slice_at_addr(
521                     desc.addr,
522                     desc.len.try_into().expect("u32 doesn't fit in usize"),
523                 )
524                 .map_err(Error::GuestMemoryError)?;
525 
526                 Ok(MemRegion {
527                     offset: desc.addr.0,
528                     len: desc.len.try_into().expect("u32 doesn't fit in usize"),
529                 })
530             })
531             .collect::<Result<SmallVec<[MemRegion; 16]>>>()?;
532         Ok(Writer {
533             mem,
534             regions: DescriptorChainRegions {
535                 regions,
536                 current: 0,
537                 bytes_consumed: 0,
538             },
539         })
540     }
541 
542     /// Writes an object to the descriptor chain buffer.
write_obj<T: DataInit>(&mut self, val: T) -> io::Result<()>543     pub fn write_obj<T: DataInit>(&mut self, val: T) -> io::Result<()> {
544         self.write_all(val.as_slice())
545     }
546 
547     /// Writes all objects produced by `iter` into the descriptor chain buffer. Unlike `consume`,
548     /// this doesn't require the values to be stored in an intermediate collection first. It also
549     /// allows callers to choose which elements in a collection to write, for example by using the
550     /// `filter` or `take` methods of the `Iterator` trait.
write_iter<T: DataInit, I: Iterator<Item = T>>( &mut self, mut iter: I, ) -> io::Result<()>551     pub fn write_iter<T: DataInit, I: Iterator<Item = T>>(
552         &mut self,
553         mut iter: I,
554     ) -> io::Result<()> {
555         iter.try_for_each(|v| self.write_obj(v))
556     }
557 
558     /// Writes a collection of objects into the descriptor chain buffer.
consume<T: DataInit, C: IntoIterator<Item = T>>(&mut self, vals: C) -> io::Result<()>559     pub fn consume<T: DataInit, C: IntoIterator<Item = T>>(&mut self, vals: C) -> io::Result<()> {
560         self.write_iter(vals.into_iter())
561     }
562 
563     /// Returns number of bytes available for writing.  May return an error if the combined
564     /// lengths of all the buffers in the DescriptorChain would cause an overflow.
available_bytes(&self) -> usize565     pub fn available_bytes(&self) -> usize {
566         self.regions.available_bytes()
567     }
568 
569     /// Writes data to the descriptor chain buffer from a file descriptor.
570     /// Returns the number of bytes written to the descriptor chain buffer.
571     /// The number of bytes written can be less than `count` if
572     /// there isn't enough data in the descriptor chain buffer.
write_from<F: FileReadWriteVolatile>( &mut self, mut src: F, count: usize, ) -> io::Result<usize>573     pub fn write_from<F: FileReadWriteVolatile>(
574         &mut self,
575         mut src: F,
576         count: usize,
577     ) -> io::Result<usize> {
578         let iovs = self.regions.get_remaining_with_count(&self.mem, count);
579         let read = src.read_vectored_volatile(&iovs[..])?;
580         self.regions.consume(read);
581         Ok(read)
582     }
583 
584     /// Writes data to the descriptor chain buffer from a File at offset `off`.
585     /// Returns the number of bytes written to the descriptor chain buffer.
586     /// The number of bytes written can be less than `count` if
587     /// there isn't enough data in the descriptor chain buffer.
write_from_at<F: FileReadWriteAtVolatile>( &mut self, mut src: F, count: usize, off: u64, ) -> io::Result<usize>588     pub fn write_from_at<F: FileReadWriteAtVolatile>(
589         &mut self,
590         mut src: F,
591         count: usize,
592         off: u64,
593     ) -> io::Result<usize> {
594         let iovs = self.regions.get_remaining_with_count(&self.mem, count);
595         let read = src.read_vectored_at_volatile(&iovs[..], off)?;
596         self.regions.consume(read);
597         Ok(read)
598     }
599 
write_all_from<F: FileReadWriteVolatile>( &mut self, mut src: F, mut count: usize, ) -> io::Result<()>600     pub fn write_all_from<F: FileReadWriteVolatile>(
601         &mut self,
602         mut src: F,
603         mut count: usize,
604     ) -> io::Result<()> {
605         while count > 0 {
606             match self.write_from(&mut src, count) {
607                 Ok(0) => {
608                     return Err(io::Error::new(
609                         io::ErrorKind::WriteZero,
610                         "failed to write whole buffer",
611                     ))
612                 }
613                 Ok(n) => count -= n,
614                 Err(ref e) if e.kind() == io::ErrorKind::Interrupted => {}
615                 Err(e) => return Err(e),
616             }
617         }
618 
619         Ok(())
620     }
621 
write_all_from_at<F: FileReadWriteAtVolatile>( &mut self, mut src: F, mut count: usize, mut off: u64, ) -> io::Result<()>622     pub fn write_all_from_at<F: FileReadWriteAtVolatile>(
623         &mut self,
624         mut src: F,
625         mut count: usize,
626         mut off: u64,
627     ) -> io::Result<()> {
628         while count > 0 {
629             match self.write_from_at(&mut src, count, off) {
630                 Ok(0) => {
631                     return Err(io::Error::new(
632                         io::ErrorKind::WriteZero,
633                         "failed to write whole buffer",
634                     ))
635                 }
636                 Ok(n) => {
637                     count -= n;
638                     off += n as u64;
639                 }
640                 Err(ref e) if e.kind() == io::ErrorKind::Interrupted => {}
641                 Err(e) => return Err(e),
642             }
643         }
644         Ok(())
645     }
646     /// Writes data to the descriptor chain buffer from an `AsyncDisk` at offset `off`.
647     /// Returns the number of bytes written to the descriptor chain buffer.
648     /// The number of bytes written can be less than `count` if
649     /// there isn't enough data in the descriptor chain buffer.
write_from_at_fut<F: AsyncDisk + ?Sized>( &mut self, src: &F, count: usize, off: u64, ) -> disk::Result<usize>650     pub async fn write_from_at_fut<F: AsyncDisk + ?Sized>(
651         &mut self,
652         src: &F,
653         count: usize,
654         off: u64,
655     ) -> disk::Result<usize> {
656         let regions = self.regions.get_remaining_regions_with_count(count);
657         let read = src
658             .read_to_mem(off, Arc::new(self.mem.clone()), &regions)
659             .await?;
660         self.regions.consume(read);
661         Ok(read)
662     }
663 
write_all_from_at_fut<F: AsyncDisk + ?Sized>( &mut self, src: &F, mut count: usize, mut off: u64, ) -> disk::Result<()>664     pub async fn write_all_from_at_fut<F: AsyncDisk + ?Sized>(
665         &mut self,
666         src: &F,
667         mut count: usize,
668         mut off: u64,
669     ) -> disk::Result<()> {
670         while count > 0 {
671             let nwritten = self.write_from_at_fut(src, count, off).await?;
672             if nwritten == 0 {
673                 return Err(disk::Error::WritingData(io::Error::new(
674                     io::ErrorKind::UnexpectedEof,
675                     "failed to write whole buffer",
676                 )));
677             }
678             count -= nwritten;
679             off += nwritten as u64;
680         }
681         Ok(())
682     }
683 
684     /// Returns number of bytes already written to the descriptor chain buffer.
bytes_written(&self) -> usize685     pub fn bytes_written(&self) -> usize {
686         self.regions.bytes_consumed()
687     }
688 
689     /// Splits this `Writer` into two at the given offset in the `DescriptorChain` buffer. After the
690     /// split, `self` will be able to write up to `offset` bytes while the returned `Writer` can
691     /// write up to `available_bytes() - offset` bytes. If `offset > self.available_bytes()`, then
692     /// the returned `Writer` will not be able to write any bytes.
split_at(&mut self, offset: usize) -> Writer693     pub fn split_at(&mut self, offset: usize) -> Writer {
694         Writer {
695             mem: self.mem.clone(),
696             regions: self.regions.split_at(offset),
697         }
698     }
699 }
700 
701 impl io::Write for Writer {
write(&mut self, buf: &[u8]) -> io::Result<usize>702     fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
703         let mut rem = buf;
704         let mut total = 0;
705         for b in self.regions.get_remaining(&self.mem) {
706             if rem.is_empty() {
707                 break;
708             }
709 
710             let count = cmp::min(rem.len(), b.size());
711             // Safe because we have already verified that `vs` points to valid memory.
712             unsafe {
713                 copy_nonoverlapping(rem.as_ptr(), b.as_mut_ptr(), count);
714             }
715             rem = &rem[count..];
716             total += count;
717         }
718 
719         self.regions.consume(total);
720         Ok(total)
721     }
722 
flush(&mut self) -> io::Result<()>723     fn flush(&mut self) -> io::Result<()> {
724         // Nothing to flush since the writes go straight into the buffer.
725         Ok(())
726     }
727 }
728 
/// Descriptor flag set by `create_descriptor_chain` on every descriptor except the last one
/// in a chain.
const VIRTQ_DESC_F_NEXT: u16 = 1 << 0;
/// Descriptor flag set by `create_descriptor_chain` on `DescriptorType::Writable` entries.
const VIRTQ_DESC_F_WRITE: u16 = 1 << 1;
731 
/// Kind of descriptor that `create_descriptor_chain` writes for each requested buffer.
///
/// `Debug` is derived so values can appear in `assert_eq!`/`assert_ne!` failures and `{:?}`
/// output.
#[derive(Copy, Clone, Debug, PartialEq, Eq)]
pub enum DescriptorType {
    /// A descriptor written without `VIRTQ_DESC_F_WRITE`.
    Readable,
    /// A descriptor written with `VIRTQ_DESC_F_WRITE`.
    Writable,
}
737 
738 #[derive(Copy, Clone, Debug)]
739 #[repr(C)]
740 struct virtq_desc {
741     addr: Le64,
742     len: Le32,
743     flags: Le16,
744     next: Le16,
745 }
746 
747 // Safe because it only has data and has no implicit padding.
748 unsafe impl DataInit for virtq_desc {}
749 
750 /// Test utility function to create a descriptor chain in guest memory.
create_descriptor_chain( memory: &GuestMemory, descriptor_array_addr: GuestAddress, mut buffers_start_addr: GuestAddress, descriptors: Vec<(DescriptorType, u32)>, spaces_between_regions: u32, ) -> Result<DescriptorChain>751 pub fn create_descriptor_chain(
752     memory: &GuestMemory,
753     descriptor_array_addr: GuestAddress,
754     mut buffers_start_addr: GuestAddress,
755     descriptors: Vec<(DescriptorType, u32)>,
756     spaces_between_regions: u32,
757 ) -> Result<DescriptorChain> {
758     let descriptors_len = descriptors.len();
759     for (index, (type_, size)) in descriptors.into_iter().enumerate() {
760         let mut flags = 0;
761         if let DescriptorType::Writable = type_ {
762             flags |= VIRTQ_DESC_F_WRITE;
763         }
764         if index + 1 < descriptors_len {
765             flags |= VIRTQ_DESC_F_NEXT;
766         }
767 
768         let index = index as u16;
769         let desc = virtq_desc {
770             addr: buffers_start_addr.offset().into(),
771             len: size.into(),
772             flags: flags.into(),
773             next: (index + 1).into(),
774         };
775 
776         let offset = size + spaces_between_regions;
777         buffers_start_addr = buffers_start_addr
778             .checked_add(offset as u64)
779             .ok_or(Error::InvalidChain)?;
780 
781         let _ = memory.write_obj_at_addr(
782             desc,
783             descriptor_array_addr
784                 .checked_add(index as u64 * std::mem::size_of::<virtq_desc>() as u64)
785                 .ok_or(Error::InvalidChain)?,
786         );
787     }
788 
789     DescriptorChain::checked_new(memory, descriptor_array_addr, 0x100, 0, 0)
790         .ok_or(Error::InvalidChain)
791 }
792 
#[cfg(test)]
mod tests {
    use super::*;
    use std::fs::File;
    use tempfile::tempfile;

    use cros_async::Executor;

    /// Creates the guest memory used by every test: a single 0x10000 region at guest address 0.
    fn new_memory() -> GuestMemory {
        GuestMemory::new(&[(GuestAddress(0x0), 0x10000)]).unwrap()
    }

    /// Writes a descriptor chain into `memory` with the descriptor table at guest address 0x0
    /// and the buffers starting at guest address 0x100.
    fn new_chain(
        memory: &GuestMemory,
        descriptors: Vec<(DescriptorType, u32)>,
        spaces_between_regions: u32,
    ) -> DescriptorChain {
        create_descriptor_chain(
            memory,
            GuestAddress(0x0),
            GuestAddress(0x100),
            descriptors,
            spaces_between_regions,
        )
        .expect("create_descriptor_chain failed")
    }

    /// Descriptor layout shared by several tests: 128 readable bytes followed by 68 writable
    /// bytes.
    fn mixed_descriptors() -> Vec<(DescriptorType, u32)> {
        use DescriptorType::*;

        vec![
            (Readable, 16),
            (Readable, 16),
            (Readable, 96),
            (Writable, 64),
            (Writable, 1),
            (Writable, 3),
        ]
    }

    /// Builds a `Reader` over `mixed_descriptors()`, splits it at `offset`, and returns
    /// `(original, split_off)`.
    fn split_mixed_reader(offset: usize) -> (Reader, Reader) {
        let memory = new_memory();
        let chain = new_chain(&memory, mixed_descriptors(), 0);
        let mut reader = Reader::new(memory.clone(), chain).expect("failed to create Reader");
        let other = reader.split_at(offset);
        (reader, other)
    }

    #[test]
    fn reader_test_simple_chain() {
        use DescriptorType::*;

        let memory = new_memory();
        let chain = new_chain(
            &memory,
            vec![
                (Readable, 8),
                (Readable, 16),
                (Readable, 18),
                (Readable, 64),
            ],
            0,
        );
        let mut reader = Reader::new(memory.clone(), chain).expect("failed to create Reader");
        assert_eq!(reader.available_bytes(), 106);
        assert_eq!(reader.bytes_read(), 0);

        let mut buffer = [0u8; 64];
        reader
            .read_exact(&mut buffer)
            .expect("read_exact should not fail here");

        assert_eq!(reader.available_bytes(), 42);
        assert_eq!(reader.bytes_read(), 64);

        let length = reader.read(&mut buffer).expect("read should not fail here");
        assert_eq!(length, 42);

        assert_eq!(reader.available_bytes(), 0);
        assert_eq!(reader.bytes_read(), 106);
    }

    #[test]
    fn writer_test_simple_chain() {
        use DescriptorType::*;

        let memory = new_memory();
        let chain = new_chain(
            &memory,
            vec![
                (Writable, 8),
                (Writable, 16),
                (Writable, 18),
                (Writable, 64),
            ],
            0,
        );
        let mut writer = Writer::new(memory.clone(), chain).expect("failed to create Writer");
        assert_eq!(writer.available_bytes(), 106);
        assert_eq!(writer.bytes_written(), 0);

        let buffer = [0u8; 64];
        writer
            .write_all(&buffer)
            .expect("write_all should not fail here");

        assert_eq!(writer.available_bytes(), 42);
        assert_eq!(writer.bytes_written(), 64);

        let length = writer.write(&buffer).expect("write should not fail here");
        assert_eq!(length, 42);

        assert_eq!(writer.available_bytes(), 0);
        assert_eq!(writer.bytes_written(), 106);
    }

    #[test]
    fn reader_test_incompatible_chain() {
        use DescriptorType::*;

        let memory = new_memory();
        let chain = new_chain(&memory, vec![(Writable, 8)], 0);
        let mut reader = Reader::new(memory.clone(), chain).expect("failed to create Reader");
        assert_eq!(reader.available_bytes(), 0);
        assert_eq!(reader.bytes_read(), 0);

        // A chain made only of writable descriptors gives the reader nothing to read.
        assert!(reader.read_obj::<u8>().is_err());

        assert_eq!(reader.available_bytes(), 0);
        assert_eq!(reader.bytes_read(), 0);
    }

    #[test]
    fn writer_test_incompatible_chain() {
        use DescriptorType::*;

        let memory = new_memory();
        let chain = new_chain(&memory, vec![(Readable, 8)], 0);
        let mut writer = Writer::new(memory.clone(), chain).expect("failed to create Writer");
        assert_eq!(writer.available_bytes(), 0);
        assert_eq!(writer.bytes_written(), 0);

        // A chain made only of readable descriptors gives the writer nowhere to write.
        assert!(writer.write_obj(0u8).is_err());

        assert_eq!(writer.available_bytes(), 0);
        assert_eq!(writer.bytes_written(), 0);
    }

    #[test]
    fn reader_failing_io() {
        use DescriptorType::*;

        let memory = new_memory();
        let chain = new_chain(&memory, vec![(Readable, 256), (Readable, 256)], 0);
        let mut reader = Reader::new(memory.clone(), chain).expect("failed to create Reader");

        // Open /dev/zero read-only so that writing to it triggers an I/O error.
        let mut ro_file = File::open("/dev/zero").expect("failed to open /dev/zero");

        reader
            .read_exact_to(&mut ro_file, 512)
            .expect_err("successfully read more bytes than SharedMemory size");

        // The write above should have failed entirely, so no bytes were consumed from the chain.
        assert_eq!(reader.available_bytes(), 512);
        assert_eq!(reader.bytes_read(), 0);
    }

    #[test]
    fn writer_failing_io() {
        use DescriptorType::*;

        let memory = new_memory();
        let chain = new_chain(&memory, vec![(Writable, 256), (Writable, 256)], 0);
        let mut writer = Writer::new(memory.clone(), chain).expect("failed to create Writer");

        // The file only holds 384 bytes, so filling all 512 writable bytes must fail.
        let mut file = tempfile().unwrap();
        file.set_len(384).unwrap();

        writer
            .write_all_from(&mut file, 512)
            .expect_err("successfully wrote more bytes than in SharedMemory");

        assert_eq!(writer.available_bytes(), 128);
        assert_eq!(writer.bytes_written(), 384);
    }

    #[test]
    fn reader_writer_shared_chain() {
        let memory = new_memory();
        let chain = new_chain(&memory, mixed_descriptors(), 0);
        let mut reader =
            Reader::new(memory.clone(), chain.clone()).expect("failed to create Reader");
        let mut writer = Writer::new(memory.clone(), chain).expect("failed to create Writer");

        assert_eq!(reader.bytes_read(), 0);
        assert_eq!(writer.bytes_written(), 0);

        let mut buffer = Vec::with_capacity(200);

        assert_eq!(
            reader
                .read_to_end(&mut buffer)
                .expect("read should not fail here"),
            128
        );

        // The writable descriptors are only 68 bytes long.
        writer
            .write_all(&buffer[..68])
            .expect("write should not fail here");

        assert_eq!(reader.available_bytes(), 0);
        assert_eq!(reader.bytes_read(), 128);
        assert_eq!(writer.available_bytes(), 0);
        assert_eq!(writer.bytes_written(), 68);
    }

    #[test]
    fn reader_writer_shattered_object() {
        use DescriptorType::*;

        let memory = new_memory();
        let secret: Le32 = 0x12345678.into();

        // Create a descriptor chain of 1-byte buffers separated by 123-byte gaps.
        let chain_writer = new_chain(
            &memory,
            vec![(Writable, 1), (Writable, 1), (Writable, 1), (Writable, 1)],
            123,
        );
        let mut writer =
            Writer::new(memory.clone(), chain_writer).expect("failed to create Writer");
        writer
            .write_obj(secret)
            .expect("write_obj should not fail here");

        // Now create a new descriptor chain over the same memory and read the object back.
        let chain_reader = new_chain(
            &memory,
            vec![(Readable, 1), (Readable, 1), (Readable, 1), (Readable, 1)],
            123,
        );
        let mut reader =
            Reader::new(memory.clone(), chain_reader).expect("failed to create Reader");
        let read_secret = reader
            .read_obj::<Le32>()
            .expect("read_obj should not fail here");
        assert_eq!(read_secret, secret);
    }

    #[test]
    fn reader_unexpected_eof() {
        use DescriptorType::*;

        let memory = new_memory();
        let chain = new_chain(&memory, vec![(Readable, 256), (Readable, 256)], 0);
        let mut reader = Reader::new(memory.clone(), chain).expect("failed to create Reader");

        let mut buf = vec![0u8; 1024];

        assert_eq!(
            reader
                .read_exact(&mut buf[..])
                .expect_err("read more bytes than available")
                .kind(),
            io::ErrorKind::UnexpectedEof
        );
    }

    #[test]
    fn split_border() {
        let (reader, other) = split_mixed_reader(32);
        assert_eq!(reader.available_bytes(), 32);
        assert_eq!(other.available_bytes(), 96);
    }

    #[test]
    fn split_middle() {
        let (reader, other) = split_mixed_reader(24);
        assert_eq!(reader.available_bytes(), 24);
        assert_eq!(other.available_bytes(), 104);
    }

    #[test]
    fn split_end() {
        let (reader, other) = split_mixed_reader(128);
        assert_eq!(reader.available_bytes(), 128);
        assert_eq!(other.available_bytes(), 0);
    }

    #[test]
    fn split_beginning() {
        let (reader, other) = split_mixed_reader(0);
        assert_eq!(reader.available_bytes(), 0);
        assert_eq!(other.available_bytes(), 128);
    }

    #[test]
    fn split_outofbounds() {
        let (_reader, other) = split_mixed_reader(256);
        assert_eq!(
            other.available_bytes(),
            0,
            "Reader returned from out-of-bounds split still has available bytes"
        );
    }

    #[test]
    fn read_full() {
        use DescriptorType::*;

        let memory = new_memory();
        let chain = new_chain(
            &memory,
            vec![(Readable, 16), (Readable, 16), (Readable, 16)],
            0,
        );
        let mut reader = Reader::new(memory.clone(), chain).expect("failed to create Reader");

        let mut buf = vec![0u8; 64];
        assert_eq!(
            reader.read(&mut buf[..]).expect("failed to read to buffer"),
            48
        );
    }

    #[test]
    fn write_full() {
        use DescriptorType::*;

        let memory = new_memory();
        let chain = new_chain(
            &memory,
            vec![(Writable, 16), (Writable, 16), (Writable, 16)],
            0,
        );
        let mut writer = Writer::new(memory.clone(), chain).expect("failed to create Writer");

        let buf = vec![0xdeu8; 64];
        assert_eq!(
            writer.write(&buf[..]).expect("failed to write from buffer"),
            48
        );
    }

    #[test]
    fn consume_collect() {
        use DescriptorType::*;

        let memory = new_memory();
        let vs: Vec<Le64> = vec![
            0x0101010101010101.into(),
            0x0202020202020202.into(),
            0x0303030303030303.into(),
        ];

        let write_chain = new_chain(&memory, vec![(Writable, 24)], 0);
        let mut writer = Writer::new(memory.clone(), write_chain).expect("failed to create Writer");
        writer
            .consume(vs.clone())
            .expect("failed to consume() a vector");

        // A second chain over the same buffer address reads the values back.
        let read_chain = new_chain(&memory, vec![(Readable, 24)], 0);
        let mut reader = Reader::new(memory.clone(), read_chain).expect("failed to create Reader");
        let vs_read = reader
            .collect::<io::Result<Vec<Le64>>, _>()
            .expect("failed to collect() values");
        assert_eq!(vs, vs_read);
    }

    #[test]
    fn get_remaining_region_with_count() {
        let memory = new_memory();
        let chain = new_chain(&memory, mixed_descriptors(), 0);

        let Reader {
            mem: _,
            mut regions,
        } = Reader::new(memory.clone(), chain).expect("failed to create Reader");

        let drain = regions
            .get_remaining_regions_with_count(::std::usize::MAX)
            .iter()
            .map(|region| region.len)
            .sum::<usize>();
        assert_eq!(drain, 128);

        let exact = regions
            .get_remaining_regions_with_count(32)
            .iter()
            .map(|region| region.len)
            .sum::<usize>();
        assert!(exact > 0);
        assert!(exact <= 32);

        let split = regions
            .get_remaining_regions_with_count(24)
            .iter()
            .map(|region| region.len)
            .sum::<usize>();
        assert!(split > 0);
        assert!(split <= 24);

        regions.consume(64);

        let first = regions
            .get_remaining_regions_with_count(8)
            .iter()
            .map(|region| region.len)
            .sum::<usize>();
        assert!(first > 0);
        assert!(first <= 8);
    }

    #[test]
    fn get_remaining_with_count() {
        let memory = new_memory();
        let chain = new_chain(&memory, mixed_descriptors(), 0);

        let Reader {
            mem: _,
            mut regions,
        } = Reader::new(memory.clone(), chain).expect("failed to create Reader");

        let drain = regions
            .get_remaining_with_count(&memory, ::std::usize::MAX)
            .iter()
            .map(|iov| iov.size())
            .sum::<usize>();
        assert_eq!(drain, 128);

        let exact = regions
            .get_remaining_with_count(&memory, 32)
            .iter()
            .map(|iov| iov.size())
            .sum::<usize>();
        assert!(exact > 0);
        assert!(exact <= 32);

        let split = regions
            .get_remaining_with_count(&memory, 24)
            .iter()
            .map(|iov| iov.size())
            .sum::<usize>();
        assert!(split > 0);
        assert!(split <= 24);

        regions.consume(64);

        let first = regions
            .get_remaining_with_count(&memory, 8)
            .iter()
            .map(|iov| iov.size())
            .sum::<usize>();
        assert!(first > 0);
        assert!(first <= 8);
    }

    #[test]
    fn region_reader_failing_io() {
        let ex = Executor::new().unwrap();
        ex.run_until(region_reader_failing_io_async(&ex)).unwrap();
    }

    async fn region_reader_failing_io_async(ex: &Executor) {
        use DescriptorType::*;

        let memory = new_memory();
        let chain = new_chain(&memory, vec![(Readable, 256), (Readable, 256)], 0);
        let mut reader = Reader::new(memory.clone(), chain).expect("failed to create Reader");

        // Open /dev/zero read-only so that writing to it triggers an I/O error.
        let ro_file = File::open("/dev/zero").expect("failed to open /dev/zero");
        let async_ro_file =
            disk::SingleFileDisk::new(ro_file, ex).expect("Failed to create SFD");

        reader
            .read_exact_to_at_fut(&async_ro_file, 512, 0)
            .await
            .expect_err("successfully read more bytes than SharedMemory size");

        // The write above should have failed entirely, so no bytes were consumed from the chain.
        assert_eq!(reader.available_bytes(), 512);
        assert_eq!(reader.bytes_read(), 0);
    }

    #[test]
    fn region_writer_failing_io() {
        let ex = Executor::new().unwrap();
        ex.run_until(region_writer_failing_io_async(&ex)).unwrap();
    }

    async fn region_writer_failing_io_async(ex: &Executor) {
        use DescriptorType::*;

        let memory = new_memory();
        let chain = new_chain(&memory, vec![(Writable, 256), (Writable, 256)], 0);
        let mut writer = Writer::new(memory.clone(), chain).expect("failed to create Writer");

        // The file only holds 384 bytes, so filling all 512 writable bytes must fail.
        let file = tempfile().expect("failed to create temp file");
        file.set_len(384).unwrap();
        let async_file = disk::SingleFileDisk::new(file, ex).expect("Failed to create SFD");

        writer
            .write_all_from_at_fut(&async_file, 512, 0)
            .await
            .expect_err("successfully wrote more bytes than in SharedMemory");

        assert_eq!(writer.available_bytes(), 128);
        assert_eq!(writer.bytes_written(), 384);
    }
}
1536