// Copyright 2019 The ChromiumOS Authors
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.
4 
5 mod sys;
6 
7 use std::borrow::Cow;
8 use std::cmp;
9 use std::convert::TryInto;
10 use std::io;
11 use std::io::Write;
12 use std::iter::FromIterator;
13 use std::marker::PhantomData;
14 use std::ptr::copy_nonoverlapping;
15 use std::result;
16 use std::sync::Arc;
17 
18 use anyhow::Context;
19 use base::FileReadWriteAtVolatile;
20 use base::FileReadWriteVolatile;
21 use cros_async::MemRegion;
22 use data_model::zerocopy_from_reader;
23 use data_model::Le16;
24 use data_model::Le32;
25 use data_model::Le64;
26 use data_model::VolatileMemoryError;
27 use data_model::VolatileSlice;
28 use disk::AsyncDisk;
29 use remain::sorted;
30 use smallvec::SmallVec;
31 use thiserror::Error;
32 use vm_memory::GuestAddress;
33 use vm_memory::GuestMemory;
34 use zerocopy::AsBytes;
35 use zerocopy::FromBytes;
36 
37 use super::DescriptorChain;
38 use crate::virtio::ipc_memory_mapper::ExportedRegion;
39 
/// Errors produced while building or using a `Reader`/`Writer` over a `DescriptorChain`.
///
/// The variants are kept in alphabetical order (enforced by `#[sorted]`).
#[sorted]
#[derive(Error, Debug)]
pub enum Error {
    #[error("the combined length of all the buffers in a `DescriptorChain` would overflow")]
    DescriptorChainOverflow,
    #[error("descriptor guest memory error: {0}")]
    GuestMemoryError(vm_memory::GuestMemoryError),
    #[error("invalid descriptor chain: {0:#}")]
    InvalidChain(anyhow::Error),
    #[error("descriptor I/O error: {0}")]
    IoError(io::Error),
    #[error("`DescriptorChain` split is out of bounds: {0}")]
    SplitOutOfBounds(usize),
    #[error("volatile memory error: {0}")]
    VolatileMemoryError(VolatileMemoryError),
}
56 
/// Convenience alias for results whose error type is this module's [`Error`].
pub type Result<T> = result::Result<T, Error>;
58 
/// Cursor over the memory regions of a descriptor chain.
///
/// Shared by `Reader` and `Writer` to track how far into the chain they have advanced.
#[derive(Clone)]
struct DescriptorChainRegions {
    // All regions of the chain; partially consumed regions are shrunk in place by `consume`.
    regions: DescriptorChainMemRegions,
    // Index of the first region that has not been fully consumed.
    current: usize,
    // Running total of bytes consumed so far.
    bytes_consumed: usize,
}
65 
66 impl DescriptorChainRegions {
available_bytes(&self) -> usize67     fn available_bytes(&self) -> usize {
68         // This is guaranteed not to overflow because the total length of the chain
69         // is checked during all creations of `DescriptorChainRegions` (see
70         // `Reader::new()` and `Writer::new()`).
71         self.get_remaining_regions()
72             .iter()
73             .fold(0usize, |count, region| count + region.len)
74     }
75 
    /// Returns the total number of bytes consumed so far.
    fn bytes_consumed(&self) -> usize {
        self.bytes_consumed
    }
79 
    /// Returns all the remaining regions in the `DescriptorChain`, starting at the first region
    /// that has not been fully consumed. Calling this function does not consume any bytes from
    /// the `DescriptorChain`; callers should use the `consume` method to advance it. Multiple
    /// calls to `get_remaining_regions` with no intervening calls to `consume` return the same
    /// data.
    fn get_remaining_regions(&self) -> &[MemRegion] {
        &self.regions.regions[self.current..]
    }
87 
88     /// Returns all the remaining buffers in the `DescriptorChain` as `VolatileSlice`s of the given
89     /// `GuestMemory`. Calling this function does not consume any bytes from the `DescriptorChain`.
90     /// Instead callers should use the `consume` method to advance the `DescriptorChain`. Multiple
91     /// calls to `get` with no intervening calls to `consume` will return the same data.
get_remaining<'mem>(&self, mem: &'mem GuestMemory) -> SmallVec<[VolatileSlice<'mem>; 16]>92     fn get_remaining<'mem>(&self, mem: &'mem GuestMemory) -> SmallVec<[VolatileSlice<'mem>; 16]> {
93         self.get_remaining_regions()
94             .iter()
95             .filter_map(|region| {
96                 mem.get_slice_at_addr(GuestAddress(region.offset), region.len)
97                     .ok()
98             })
99             .collect()
100     }
101 
102     /// Like `get_remaining` but guarantees that the combined length of all the returned iovecs is
103     /// not greater than `count`. The combined length of the returned iovecs may be less than
104     /// `count` but will always be greater than 0 as long as there is still space left in the
105     /// `DescriptorChain`.
get_remaining_regions_with_count(&self, count: usize) -> Cow<[MemRegion]>106     fn get_remaining_regions_with_count(&self, count: usize) -> Cow<[MemRegion]> {
107         let regions = self.get_remaining_regions();
108         let mut region_count = 0;
109         let mut rem = count;
110         for region in regions {
111             if rem < region.len {
112                 break;
113             }
114 
115             region_count += 1;
116             rem -= region.len;
117         }
118 
119         // Special case where the number of bytes to be copied is smaller than the `size()` of the
120         // first regions.
121         if region_count == 0 && !regions.is_empty() && count > 0 {
122             debug_assert!(count < regions[0].len);
123             // Safe because we know that count is smaller than the length of the first slice.
124             Cow::Owned(vec![MemRegion {
125                 offset: regions[0].offset,
126                 len: count,
127             }])
128         } else {
129             Cow::Borrowed(&regions[..region_count])
130         }
131     }
132 
133     /// Like 'get_remaining_with_count' except convert the offsets to volatile slices in the
134     /// 'GuestMemory' given by 'mem'.
get_remaining_with_count<'mem>( &self, mem: &'mem GuestMemory, count: usize, ) -> SmallVec<[VolatileSlice<'mem>; 16]>135     fn get_remaining_with_count<'mem>(
136         &self,
137         mem: &'mem GuestMemory,
138         count: usize,
139     ) -> SmallVec<[VolatileSlice<'mem>; 16]> {
140         self.get_remaining_regions_with_count(count)
141             .iter()
142             .filter_map(|region| {
143                 mem.get_slice_at_addr(GuestAddress(region.offset), region.len)
144                     .ok()
145             })
146             .collect()
147     }
148 
149     /// Consumes `count` bytes from the `DescriptorChain`. If `count` is larger than
150     /// `self.available_bytes()` then all remaining bytes in the `DescriptorChain` will be consumed.
consume(&mut self, mut count: usize)151     fn consume(&mut self, mut count: usize) {
152         // The implementation is adapted from `IoSlice::advance` in libstd. We can't use
153         // `get_remaining` here because then the compiler complains that `self.current` is already
154         // borrowed and doesn't allow us to modify it.  We also need to borrow the iovecs mutably.
155         let current = self.current;
156         for region in &mut self.regions.regions[current..] {
157             if count == 0 {
158                 break;
159             }
160 
161             let consumed = if count < region.len {
162                 // Safe because we know that the iovec pointed to valid memory and we are adding a
163                 // value that is smaller than the length of the memory.
164                 *region = MemRegion {
165                     offset: region.offset + count as u64,
166                     len: region.len - count,
167                 };
168                 count
169             } else {
170                 self.current += 1;
171                 region.len
172             };
173 
174             // This shouldn't overflow because `consumed <= buf.size()` and we already verified
175             // that adding all `buf.size()` values will not overflow when the Reader/Writer was
176             // constructed.
177             self.bytes_consumed += consumed;
178             count -= consumed;
179         }
180     }
181 
    /// Splits the remaining regions in two at `offset` bytes from the current position.
    ///
    /// After the call `self` is truncated so that at most `offset` bytes remain in it. The
    /// returned value is a clone of the original that has been advanced by `offset` bytes, with
    /// its consumed-byte counter reset to 0.
    fn split_at(&mut self, offset: usize) -> DescriptorChainRegions {
        // Build the tail first, while `self` still describes the whole chain.
        let mut other = self.clone();
        other.consume(offset);
        other.bytes_consumed = 0;

        // Walk the remaining regions until `offset` bytes are covered, shortening the region in
        // which the split point falls.
        let mut rem = offset;
        let mut end = self.current;
        for region in &mut self.regions.regions[self.current..] {
            if rem <= region.len {
                region.len = rem;
                break;
            }

            end += 1;
            rem -= region.len;
        }

        // Drop every region after the one containing the split point. This is a no-op when
        // `offset` exceeds the available bytes, because `end` then equals the region count.
        self.regions.regions.truncate(end + 1);

        other
    }
203 }
204 
/// Provides high-level interface over the sequence of memory regions
/// defined by readable descriptors in the descriptor chain.
///
/// Note that virtio spec requires driver to place any device-writable
/// descriptors after any device-readable descriptors (2.6.4.2 in Virtio Spec v1.1).
/// Reader will skip iterating over descriptor chain when first writable
/// descriptor is encountered.
#[derive(Clone)]
pub struct Reader {
    // Guest memory that the region offsets are resolved against.
    mem: GuestMemory,
    // Cursor over the readable regions of the chain.
    regions: DescriptorChainRegions,
}
217 
/// An iterator over `FromBytes` objects on readable descriptors in the descriptor chain.
///
/// Created by `Reader::iter`; each call to `next` reads one `T` from the borrowed `Reader`.
pub struct ReaderIterator<'a, T: FromBytes> {
    // The reader that objects are pulled from.
    reader: &'a mut Reader,
    // Ties the iterator to the item type without storing a `T`.
    phantom: PhantomData<T>,
}
223 
224 impl<'a, T: FromBytes> Iterator for ReaderIterator<'a, T> {
225     type Item = io::Result<T>;
226 
next(&mut self) -> Option<io::Result<T>>227     fn next(&mut self) -> Option<io::Result<T>> {
228         if self.reader.available_bytes() == 0 {
229             None
230         } else {
231             Some(self.reader.read_obj())
232         }
233     }
234 }
235 
/// The memory regions of a descriptor chain, together with any exported IOVA regions that must
/// be kept alive while those memory regions are in use.
#[derive(Clone)]
pub struct DescriptorChainMemRegions {
    /// The memory regions described by the chain, in chain order.
    pub regions: SmallVec<[cros_async::MemRegion; 16]>,
    // For virtio devices that operate on IOVAs rather than guest physical
    // addresses, the IOVA regions must be exported from virtio-iommu to get
    // the underlying memory regions. It is only valid for the virtio device
    // to access those memory regions while they remain exported, so maintain
    // references to the exported regions until the descriptor chain is
    // dropped.
    _exported_regions: Option<Vec<ExportedRegion>>,
}
247 
248 /// Get all the mem regions from a `DescriptorChain` iterator, regardless if the `DescriptorChain`
249 /// contains GPAs (guest physical address), or IOVAs (io virtual address). IOVAs will
250 /// be translated to GPAs via IOMMU.
get_mem_regions<I>(mem: &GuestMemory, vals: I) -> Result<DescriptorChainMemRegions> where I: Iterator<Item = DescriptorChain>,251 pub fn get_mem_regions<I>(mem: &GuestMemory, vals: I) -> Result<DescriptorChainMemRegions>
252 where
253     I: Iterator<Item = DescriptorChain>,
254 {
255     let mut total_len: usize = 0;
256     let mut regions = SmallVec::new();
257     let mut exported_regions: Option<Vec<ExportedRegion>> = None;
258 
259     // TODO(jstaron): Update this code to take the indirect descriptors into account.
260     for desc in vals {
261         // Verify that summing the descriptor sizes does not overflow.
262         // This can happen if a driver tricks a device into reading/writing more data than
263         // fits in a `usize`.
264         total_len = total_len
265             .checked_add(desc.len as usize)
266             .ok_or(Error::DescriptorChainOverflow)?;
267 
268         let (desc_regions, exported) = desc.into_mem_regions();
269         for r in desc_regions {
270             // Check that all the regions are totally contained in GuestMemory.
271             mem.get_slice_at_addr(r.gpa, r.len.try_into().expect("u32 doesn't fit in usize"))
272                 .map_err(Error::GuestMemoryError)?;
273 
274             regions.push(cros_async::MemRegion {
275                 offset: r.gpa.offset(),
276                 len: r.len.try_into().expect("u32 doesn't fit in usize"),
277             });
278         }
279         if let Some(exported) = exported {
280             exported_regions.get_or_insert(vec![]).push(exported);
281         }
282     }
283 
284     Ok(DescriptorChainMemRegions {
285         regions,
286         _exported_regions: exported_regions,
287     })
288 }
289 
290 impl Reader {
291     /// Construct a new Reader wrapper over `desc_chain`.
new(mem: GuestMemory, desc_chain: DescriptorChain) -> Result<Reader>292     pub fn new(mem: GuestMemory, desc_chain: DescriptorChain) -> Result<Reader> {
293         let regions = get_mem_regions(&mem, desc_chain.into_iter().readable())?;
294         Ok(Reader {
295             mem,
296             regions: DescriptorChainRegions {
297                 regions,
298                 current: 0,
299                 bytes_consumed: 0,
300             },
301         })
302     }
303 
    /// Reads an object of type `T` from the descriptor chain buffer by delegating to
    /// `zerocopy_from_reader`, which pulls bytes through this reader's `io::Read` implementation.
    pub fn read_obj<T: FromBytes>(&mut self) -> io::Result<T> {
        zerocopy_from_reader(self)
    }
308 
    /// Reads objects by consuming all the remaining data in the descriptor chain buffer and
    /// gathers the per-object `io::Result<T>` values produced by `iter()` into a collection `C`.
    ///
    /// NOTE(review): a trailing chunk smaller than `T` presumably surfaces as an `Err` item from
    /// `read_obj` — confirm against `zerocopy_from_reader`.
    pub fn collect<C: FromIterator<io::Result<T>>, T: FromBytes>(&mut self) -> C {
        self.iter().collect()
    }
315 
    /// Creates an iterator for sequentially reading `FromBytes` objects from the `Reader`.
    /// Unlike `collect`, this doesn't consume all the remaining data in the `Reader` and
    /// doesn't require the objects to be stored in a separate collection.
    ///
    /// The iterator borrows the `Reader` mutably and yields `None` once no bytes are available.
    pub fn iter<T: FromBytes>(&mut self) -> ReaderIterator<T> {
        ReaderIterator {
            reader: self,
            phantom: PhantomData,
        }
    }
325 
326     /// Reads data into a volatile slice up to the minimum of the slice's length or the number of
327     /// bytes remaining. Returns the number of bytes read.
read_to_volatile_slice(&mut self, slice: VolatileSlice) -> usize328     pub fn read_to_volatile_slice(&mut self, slice: VolatileSlice) -> usize {
329         let mut read = 0usize;
330         let mut dst = slice;
331         for src in self.get_remaining() {
332             src.copy_to_volatile_slice(dst);
333             let copied = std::cmp::min(src.size(), dst.size());
334             read += copied;
335             dst = match dst.offset(copied) {
336                 Ok(v) => v,
337                 Err(_) => break, // The slice is fully consumed
338             };
339         }
340         self.regions.consume(read);
341         read
342     }
343 
344     /// Reads data from the descriptor chain buffer and passes the `VolatileSlice`s to the callback
345     /// `cb`.
read_to_cb<C: FnOnce(&[VolatileSlice]) -> usize>( &mut self, cb: C, count: usize, ) -> usize346     pub fn read_to_cb<C: FnOnce(&[VolatileSlice]) -> usize>(
347         &mut self,
348         cb: C,
349         count: usize,
350     ) -> usize {
351         let iovs = self.regions.get_remaining_with_count(&self.mem, count);
352         let written = cb(&iovs[..]);
353         self.regions.consume(written);
354         written
355     }
356 
357     /// Reads data from the descriptor chain buffer into a writable object.
358     /// Returns the number of bytes read from the descriptor chain buffer.
359     /// The number of bytes read can be less than `count` if there isn't
360     /// enough data in the descriptor chain buffer.
read_to<F: FileReadWriteVolatile>( &mut self, mut dst: F, count: usize, ) -> io::Result<usize>361     pub fn read_to<F: FileReadWriteVolatile>(
362         &mut self,
363         mut dst: F,
364         count: usize,
365     ) -> io::Result<usize> {
366         let iovs = self.regions.get_remaining_with_count(&self.mem, count);
367         let written = dst.write_vectored_volatile(&iovs[..])?;
368         self.regions.consume(written);
369         Ok(written)
370     }
371 
372     /// Reads data from the descriptor chain buffer into a File at offset `off`.
373     /// Returns the number of bytes read from the descriptor chain buffer.
374     /// The number of bytes read can be less than `count` if there isn't
375     /// enough data in the descriptor chain buffer.
read_to_at<F: FileReadWriteAtVolatile>( &mut self, mut dst: F, count: usize, off: u64, ) -> io::Result<usize>376     pub fn read_to_at<F: FileReadWriteAtVolatile>(
377         &mut self,
378         mut dst: F,
379         count: usize,
380         off: u64,
381     ) -> io::Result<usize> {
382         let iovs = self.regions.get_remaining_with_count(&self.mem, count);
383         let written = dst.write_vectored_at_volatile(&iovs[..], off)?;
384         self.regions.consume(written);
385         Ok(written)
386     }
387 
388     /// Reads data from the descriptor chain similar to 'read_to' except reading 'count' or
389     /// returning an error if 'count' bytes can't be read.
read_exact_to<F: FileReadWriteVolatile>( &mut self, mut dst: F, mut count: usize, ) -> io::Result<()>390     pub fn read_exact_to<F: FileReadWriteVolatile>(
391         &mut self,
392         mut dst: F,
393         mut count: usize,
394     ) -> io::Result<()> {
395         while count > 0 {
396             match self.read_to(&mut dst, count) {
397                 Ok(0) => {
398                     return Err(io::Error::new(
399                         io::ErrorKind::UnexpectedEof,
400                         "failed to fill whole buffer",
401                     ))
402                 }
403                 Ok(n) => count -= n,
404                 Err(ref e) if e.kind() == io::ErrorKind::Interrupted => {}
405                 Err(e) => return Err(e),
406             }
407         }
408 
409         Ok(())
410     }
411 
412     /// Reads data from the descriptor chain similar to 'read_to_at' except reading 'count' or
413     /// returning an error if 'count' bytes can't be read.
read_exact_to_at<F: FileReadWriteAtVolatile>( &mut self, mut dst: F, mut count: usize, mut off: u64, ) -> io::Result<()>414     pub fn read_exact_to_at<F: FileReadWriteAtVolatile>(
415         &mut self,
416         mut dst: F,
417         mut count: usize,
418         mut off: u64,
419     ) -> io::Result<()> {
420         while count > 0 {
421             match self.read_to_at(&mut dst, count, off) {
422                 Ok(0) => {
423                     return Err(io::Error::new(
424                         io::ErrorKind::UnexpectedEof,
425                         "failed to fill whole buffer",
426                     ))
427                 }
428                 Ok(n) => {
429                     count -= n;
430                     off += n as u64;
431                 }
432                 Err(ref e) if e.kind() == io::ErrorKind::Interrupted => {}
433                 Err(e) => return Err(e),
434             }
435         }
436 
437         Ok(())
438     }
439 
440     /// Reads data from the descriptor chain buffer into an `AsyncDisk` at offset `off`.
441     /// Returns the number of bytes read from the descriptor chain buffer.
442     /// The number of bytes read can be less than `count` if there isn't
443     /// enough data in the descriptor chain buffer.
read_to_at_fut<F: AsyncDisk + ?Sized>( &mut self, dst: &F, count: usize, off: u64, ) -> disk::Result<usize>444     pub async fn read_to_at_fut<F: AsyncDisk + ?Sized>(
445         &mut self,
446         dst: &F,
447         count: usize,
448         off: u64,
449     ) -> disk::Result<usize> {
450         let mem_regions = self.regions.get_remaining_regions_with_count(count);
451         let written = dst
452             .write_from_mem(off, Arc::new(self.mem.clone()), &mem_regions)
453             .await?;
454         self.regions.consume(written);
455         Ok(written)
456     }
457 
458     /// Reads exactly `count` bytes from the chain to the disk asynchronously or returns an error if
459     /// not enough data can be read.
read_exact_to_at_fut<F: AsyncDisk + ?Sized>( &mut self, dst: &F, mut count: usize, mut off: u64, ) -> disk::Result<()>460     pub async fn read_exact_to_at_fut<F: AsyncDisk + ?Sized>(
461         &mut self,
462         dst: &F,
463         mut count: usize,
464         mut off: u64,
465     ) -> disk::Result<()> {
466         while count > 0 {
467             let nread = self.read_to_at_fut(dst, count, off).await?;
468             if nread == 0 {
469                 return Err(disk::Error::ReadingData(io::Error::new(
470                     io::ErrorKind::UnexpectedEof,
471                     "failed to write whole buffer",
472                 )));
473             }
474             count -= nread;
475             off += nread as u64;
476         }
477 
478         Ok(())
479     }
480 
    /// Returns the number of bytes available for reading. This cannot fail: the combined
    /// length of the chain's buffers is checked for overflow when the `Reader` is constructed.
    pub fn available_bytes(&self) -> usize {
        self.regions.available_bytes()
    }
486 
    /// Returns the number of bytes already read (consumed) from the descriptor chain buffer.
    pub fn bytes_read(&self) -> usize {
        self.regions.bytes_consumed()
    }
491 
    /// Returns a `SmallVec` of `VolatileSlice`s that represents all the remaining data in this
    /// `Reader`. Calling this method does not actually consume any data from the `Reader` and
    /// callers should call `consume` to advance the `Reader`.
    pub fn get_remaining(&self) -> SmallVec<[VolatileSlice; 16]> {
        self.regions.get_remaining(&self.mem)
    }
498 
    /// Consumes `amt` bytes from the underlying descriptor chain without copying them anywhere.
    /// If `amt` is larger than the remaining data left in this `Reader`, then all remaining data
    /// will be consumed.
    pub fn consume(&mut self, amt: usize) {
        self.regions.consume(amt)
    }
504 
    /// Splits this `Reader` into two at the given offset in the `DescriptorChain` buffer. After the
    /// split, `self` will be able to read up to `offset` bytes while the returned `Reader` can read
    /// up to `available_bytes() - offset` bytes. If `offset > self.available_bytes()`, then the
    /// returned `Reader` will not be able to read any bytes.
    ///
    /// The returned `Reader` uses a clone of this reader's guest memory handle.
    pub fn split_at(&mut self, offset: usize) -> Reader {
        Reader {
            mem: self.mem.clone(),
            regions: self.regions.split_at(offset),
        }
    }
515 }
516 
517 impl io::Read for Reader {
read(&mut self, buf: &mut [u8]) -> io::Result<usize>518     fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
519         let mut rem = buf;
520         let mut total = 0;
521         for b in self.regions.get_remaining(&self.mem) {
522             if rem.is_empty() {
523                 break;
524             }
525 
526             let count = cmp::min(rem.len(), b.size());
527 
528             // Safe because we have already verified that `b` points to valid memory.
529             unsafe {
530                 copy_nonoverlapping(b.as_ptr(), rem.as_mut_ptr(), count);
531             }
532             rem = &mut rem[count..];
533             total += count;
534         }
535 
536         self.regions.consume(total);
537         Ok(total)
538     }
539 }
540 
/// Provides high-level interface over the sequence of memory regions
/// defined by writable descriptors in the descriptor chain.
///
/// Note that virtio spec requires driver to place any device-writable
/// descriptors after any device-readable descriptors (2.6.4.2 in Virtio Spec v1.1).
/// Writer will start iterating the descriptors from the first writable one and will
/// assume that all following descriptors are writable.
#[derive(Clone)]
pub struct Writer {
    // Guest memory that the region offsets are resolved against.
    mem: GuestMemory,
    // Cursor over the writable regions of the chain.
    regions: DescriptorChainRegions,
}
553 
554 impl Writer {
555     /// Construct a new Writer wrapper over `desc_chain`.
new(mem: GuestMemory, desc_chain: DescriptorChain) -> Result<Writer>556     pub fn new(mem: GuestMemory, desc_chain: DescriptorChain) -> Result<Writer> {
557         let regions = get_mem_regions(&mem, desc_chain.into_iter().writable())?;
558         Ok(Writer {
559             mem,
560             regions: DescriptorChainRegions {
561                 regions,
562                 current: 0,
563                 bytes_consumed: 0,
564             },
565         })
566     }
567 
    /// Writes an object to the descriptor chain buffer by writing all of its bytes
    /// (`val.as_bytes()`) with `write_all`.
    pub fn write_obj<T: AsBytes>(&mut self, val: T) -> io::Result<()> {
        self.write_all(val.as_bytes())
    }
572 
573     /// Writes all objects produced by `iter` into the descriptor chain buffer. Unlike `consume`,
574     /// this doesn't require the values to be stored in an intermediate collection first. It also
575     /// allows callers to choose which elements in a collection to write, for example by using the
576     /// `filter` or `take` methods of the `Iterator` trait.
write_iter<T: AsBytes, I: Iterator<Item = T>>(&mut self, mut iter: I) -> io::Result<()>577     pub fn write_iter<T: AsBytes, I: Iterator<Item = T>>(&mut self, mut iter: I) -> io::Result<()> {
578         iter.try_for_each(|v| self.write_obj(v))
579     }
580 
    /// Writes a collection of objects into the descriptor chain buffer by passing its iterator
    /// to `write_iter`.
    pub fn consume<T: AsBytes, C: IntoIterator<Item = T>>(&mut self, vals: C) -> io::Result<()> {
        self.write_iter(vals.into_iter())
    }
585 
    /// Returns the number of bytes available for writing. This cannot fail: the combined
    /// length of the chain's buffers is checked for overflow when the `Writer` is constructed.
    pub fn available_bytes(&self) -> usize {
        self.regions.available_bytes()
    }
591 
592     /// Reads data into a volatile slice up to the minimum of the slice's length or the number of
593     /// bytes remaining. Returns the number of bytes read.
write_from_volatile_slice(&mut self, slice: VolatileSlice) -> usize594     pub fn write_from_volatile_slice(&mut self, slice: VolatileSlice) -> usize {
595         let mut written = 0usize;
596         let mut src = slice;
597         for dst in self.get_remaining() {
598             src.copy_to_volatile_slice(dst);
599             let copied = std::cmp::min(src.size(), dst.size());
600             written += copied;
601             src = match src.offset(copied) {
602                 Ok(v) => v,
603                 Err(_) => break, // The slice is fully consumed
604             };
605         }
606         self.regions.consume(written);
607         written
608     }
609 
610     /// Writes data to the descriptor chain buffer from a readable object.
611     /// Returns the number of bytes written to the descriptor chain buffer.
612     /// The number of bytes written can be less than `count` if
613     /// there isn't enough data in the descriptor chain buffer.
write_from<F: FileReadWriteVolatile>( &mut self, mut src: F, count: usize, ) -> io::Result<usize>614     pub fn write_from<F: FileReadWriteVolatile>(
615         &mut self,
616         mut src: F,
617         count: usize,
618     ) -> io::Result<usize> {
619         let iovs = self.regions.get_remaining_with_count(&self.mem, count);
620         let read = src.read_vectored_volatile(&iovs[..])?;
621         self.regions.consume(read);
622         Ok(read)
623     }
624 
625     /// Writes data to the descriptor chain buffer from a File at offset `off`.
626     /// Returns the number of bytes written to the descriptor chain buffer.
627     /// The number of bytes written can be less than `count` if
628     /// there isn't enough data in the descriptor chain buffer.
write_from_at<F: FileReadWriteAtVolatile>( &mut self, mut src: F, count: usize, off: u64, ) -> io::Result<usize>629     pub fn write_from_at<F: FileReadWriteAtVolatile>(
630         &mut self,
631         mut src: F,
632         count: usize,
633         off: u64,
634     ) -> io::Result<usize> {
635         let iovs = self.regions.get_remaining_with_count(&self.mem, count);
636         let read = src.read_vectored_at_volatile(&iovs[..], off)?;
637         self.regions.consume(read);
638         Ok(read)
639     }
640 
write_all_from<F: FileReadWriteVolatile>( &mut self, mut src: F, mut count: usize, ) -> io::Result<()>641     pub fn write_all_from<F: FileReadWriteVolatile>(
642         &mut self,
643         mut src: F,
644         mut count: usize,
645     ) -> io::Result<()> {
646         while count > 0 {
647             match self.write_from(&mut src, count) {
648                 Ok(0) => {
649                     return Err(io::Error::new(
650                         io::ErrorKind::WriteZero,
651                         "failed to write whole buffer",
652                     ))
653                 }
654                 Ok(n) => count -= n,
655                 Err(ref e) if e.kind() == io::ErrorKind::Interrupted => {}
656                 Err(e) => return Err(e),
657             }
658         }
659 
660         Ok(())
661     }
662 
write_all_from_at<F: FileReadWriteAtVolatile>( &mut self, mut src: F, mut count: usize, mut off: u64, ) -> io::Result<()>663     pub fn write_all_from_at<F: FileReadWriteAtVolatile>(
664         &mut self,
665         mut src: F,
666         mut count: usize,
667         mut off: u64,
668     ) -> io::Result<()> {
669         while count > 0 {
670             match self.write_from_at(&mut src, count, off) {
671                 Ok(0) => {
672                     return Err(io::Error::new(
673                         io::ErrorKind::WriteZero,
674                         "failed to write whole buffer",
675                     ))
676                 }
677                 Ok(n) => {
678                     count -= n;
679                     off += n as u64;
680                 }
681                 Err(ref e) if e.kind() == io::ErrorKind::Interrupted => {}
682                 Err(e) => return Err(e),
683             }
684         }
685         Ok(())
686     }
687     /// Writes data to the descriptor chain buffer from an `AsyncDisk` at offset `off`.
688     /// Returns the number of bytes written to the descriptor chain buffer.
689     /// The number of bytes written can be less than `count` if
690     /// there isn't enough data in the descriptor chain buffer.
write_from_at_fut<F: AsyncDisk + ?Sized>( &mut self, src: &F, count: usize, off: u64, ) -> disk::Result<usize>691     pub async fn write_from_at_fut<F: AsyncDisk + ?Sized>(
692         &mut self,
693         src: &F,
694         count: usize,
695         off: u64,
696     ) -> disk::Result<usize> {
697         let regions = self.regions.get_remaining_regions_with_count(count);
698         let read = src
699             .read_to_mem(off, Arc::new(self.mem.clone()), &regions)
700             .await?;
701         self.regions.consume(read);
702         Ok(read)
703     }
704 
write_all_from_at_fut<F: AsyncDisk + ?Sized>( &mut self, src: &F, mut count: usize, mut off: u64, ) -> disk::Result<()>705     pub async fn write_all_from_at_fut<F: AsyncDisk + ?Sized>(
706         &mut self,
707         src: &F,
708         mut count: usize,
709         mut off: u64,
710     ) -> disk::Result<()> {
711         while count > 0 {
712             let nwritten = self.write_from_at_fut(src, count, off).await?;
713             if nwritten == 0 {
714                 return Err(disk::Error::WritingData(io::Error::new(
715                     io::ErrorKind::UnexpectedEof,
716                     "failed to write whole buffer",
717                 )));
718             }
719             count -= nwritten;
720             off += nwritten as u64;
721         }
722         Ok(())
723     }
724 
    /// Returns the number of bytes already written to the descriptor chain buffer, i.e. the
    /// consumed-byte count reported by this `Writer`'s regions.
    pub fn bytes_written(&self) -> usize {
        self.regions.bytes_consumed()
    }
729 
    /// Returns a `SmallVec` of `VolatileSlice`s that represents all the remaining data in this
    /// `Writer`.
    ///
    /// Calling this method does not actually advance the current position of the `Writer` in the
    /// buffer; callers should call `consume_bytes` to advance the `Writer`. Not calling
    /// `consume_bytes` with the amount of data copied into the returned `VolatileSlice`s will
    /// result in that data being overwritten the next time data is written into the `Writer`.
    pub fn get_remaining(&self) -> SmallVec<[VolatileSlice; 16]> {
        self.regions.get_remaining(&self.mem)
    }
738 
    /// Consumes `amt` bytes from the underlying descriptor chain. If `amt` is larger than the
    /// remaining data left in this `Writer`, then all remaining data will be consumed.
    ///
    /// Intended to be paired with `get_remaining`, which hands out the slices without advancing
    /// the `Writer`.
    pub fn consume_bytes(&mut self, amt: usize) {
        self.regions.consume(amt)
    }
744 
745     /// Splits this `Writer` into two at the given offset in the `DescriptorChain` buffer. After the
746     /// split, `self` will be able to write up to `offset` bytes while the returned `Writer` can
747     /// write up to `available_bytes() - offset` bytes. If `offset > self.available_bytes()`, then
748     /// the returned `Writer` will not be able to write any bytes.
split_at(&mut self, offset: usize) -> Writer749     pub fn split_at(&mut self, offset: usize) -> Writer {
750         Writer {
751             mem: self.mem.clone(),
752             regions: self.regions.split_at(offset),
753         }
754     }
755 }
756 
757 impl io::Write for Writer {
write(&mut self, buf: &[u8]) -> io::Result<usize>758     fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
759         let mut rem = buf;
760         let mut total = 0;
761         for b in self.regions.get_remaining(&self.mem) {
762             if rem.is_empty() {
763                 break;
764             }
765 
766             let count = cmp::min(rem.len(), b.size());
767             // Safe because we have already verified that `vs` points to valid memory.
768             unsafe {
769                 copy_nonoverlapping(rem.as_ptr(), b.as_mut_ptr(), count);
770             }
771             rem = &rem[count..];
772             total += count;
773         }
774 
775         self.regions.consume(total);
776         Ok(total)
777     }
778 
flush(&mut self) -> io::Result<()>779     fn flush(&mut self) -> io::Result<()> {
780         // Nothing to flush since the writes go straight into the buffer.
781         Ok(())
782     }
783 }
784 
/// Descriptor flag set by `create_descriptor_chain` on every descriptor except the last one,
/// i.e. on descriptors whose `next` field is meant to be followed.
const VIRTQ_DESC_F_NEXT: u16 = 0x1;
/// Descriptor flag set by `create_descriptor_chain` on `DescriptorType::Writable` descriptors.
const VIRTQ_DESC_F_WRITE: u16 = 0x2;
787 
/// Kind of buffer a descriptor passed to `create_descriptor_chain` describes.
#[derive(Copy, Clone, PartialEq, Eq)]
pub enum DescriptorType {
    /// Descriptor written without `VIRTQ_DESC_F_WRITE`.
    Readable,
    /// Descriptor written with `VIRTQ_DESC_F_WRITE` set.
    Writable,
}
793 
/// Descriptor table entry as written into guest memory by `create_descriptor_chain`.
///
/// All fields are stored little-endian (`Le*` types) and the struct is `#[repr(C)]` so it can be
/// copied to guest memory as raw bytes.
#[derive(Copy, Clone, Debug, FromBytes, AsBytes)]
#[repr(C)]
struct virtq_desc {
    // Guest address of the buffer this descriptor refers to.
    addr: Le64,
    // Size of that buffer in bytes.
    len: Le32,
    // Combination of the `VIRTQ_DESC_F_*` flags.
    flags: Le16,
    // Index of the following descriptor in the table.
    next: Le16,
}
802 
803 /// Test utility function to create a descriptor chain in guest memory.
create_descriptor_chain( memory: &GuestMemory, descriptor_array_addr: GuestAddress, mut buffers_start_addr: GuestAddress, descriptors: Vec<(DescriptorType, u32)>, spaces_between_regions: u32, ) -> Result<DescriptorChain>804 pub fn create_descriptor_chain(
805     memory: &GuestMemory,
806     descriptor_array_addr: GuestAddress,
807     mut buffers_start_addr: GuestAddress,
808     descriptors: Vec<(DescriptorType, u32)>,
809     spaces_between_regions: u32,
810 ) -> Result<DescriptorChain> {
811     let descriptors_len = descriptors.len();
812     for (index, (type_, size)) in descriptors.into_iter().enumerate() {
813         let mut flags = 0;
814         if let DescriptorType::Writable = type_ {
815             flags |= VIRTQ_DESC_F_WRITE;
816         }
817         if index + 1 < descriptors_len {
818             flags |= VIRTQ_DESC_F_NEXT;
819         }
820 
821         let index = index as u16;
822         let desc = virtq_desc {
823             addr: buffers_start_addr.offset().into(),
824             len: size.into(),
825             flags: flags.into(),
826             next: (index + 1).into(),
827         };
828 
829         let offset = size + spaces_between_regions;
830         buffers_start_addr = buffers_start_addr
831             .checked_add(offset as u64)
832             .context("Invalid buffers_start_addr)")
833             .map_err(Error::InvalidChain)?;
834 
835         let _ = memory.write_obj_at_addr(
836             desc,
837             descriptor_array_addr
838                 .checked_add(index as u64 * std::mem::size_of::<virtq_desc>() as u64)
839                 .context("Invalid descriptor_array_addr")
840                 .map_err(Error::InvalidChain)?,
841         );
842     }
843 
844     DescriptorChain::checked_new(memory, descriptor_array_addr, 0x100, 0, 0, None, None)
845         .map_err(Error::InvalidChain)
846 }
847 
848 #[cfg(test)]
849 mod tests {
850     use std::fs::File;
851     use std::io::Read;
852 
853     use tempfile::tempfile;
854 
855     use super::*;
856 
857     #[test]
reader_test_simple_chain()858     fn reader_test_simple_chain() {
859         use DescriptorType::*;
860 
861         let memory_start_addr = GuestAddress(0x0);
862         let memory = GuestMemory::new(&[(memory_start_addr, 0x10000)]).unwrap();
863 
864         let chain = create_descriptor_chain(
865             &memory,
866             GuestAddress(0x0),
867             GuestAddress(0x100),
868             vec![
869                 (Readable, 8),
870                 (Readable, 16),
871                 (Readable, 18),
872                 (Readable, 64),
873             ],
874             0,
875         )
876         .expect("create_descriptor_chain failed");
877         let mut reader = Reader::new(memory, chain).expect("failed to create Reader");
878         assert_eq!(reader.available_bytes(), 106);
879         assert_eq!(reader.bytes_read(), 0);
880 
881         let mut buffer = [0u8; 64];
882         reader
883             .read_exact(&mut buffer)
884             .expect("read_exact should not fail here");
885 
886         assert_eq!(reader.available_bytes(), 42);
887         assert_eq!(reader.bytes_read(), 64);
888 
889         match reader.read(&mut buffer) {
890             Err(_) => panic!("read should not fail here"),
891             Ok(length) => assert_eq!(length, 42),
892         }
893 
894         assert_eq!(reader.available_bytes(), 0);
895         assert_eq!(reader.bytes_read(), 106);
896     }
897 
898     #[test]
writer_test_simple_chain()899     fn writer_test_simple_chain() {
900         use DescriptorType::*;
901 
902         let memory_start_addr = GuestAddress(0x0);
903         let memory = GuestMemory::new(&[(memory_start_addr, 0x10000)]).unwrap();
904 
905         let chain = create_descriptor_chain(
906             &memory,
907             GuestAddress(0x0),
908             GuestAddress(0x100),
909             vec![
910                 (Writable, 8),
911                 (Writable, 16),
912                 (Writable, 18),
913                 (Writable, 64),
914             ],
915             0,
916         )
917         .expect("create_descriptor_chain failed");
918         let mut writer = Writer::new(memory, chain).expect("failed to create Writer");
919         assert_eq!(writer.available_bytes(), 106);
920         assert_eq!(writer.bytes_written(), 0);
921 
922         let buffer = [0; 64];
923         writer
924             .write_all(&buffer)
925             .expect("write_all should not fail here");
926 
927         assert_eq!(writer.available_bytes(), 42);
928         assert_eq!(writer.bytes_written(), 64);
929 
930         match writer.write(&buffer) {
931             Err(_) => panic!("write should not fail here"),
932             Ok(length) => assert_eq!(length, 42),
933         }
934 
935         assert_eq!(writer.available_bytes(), 0);
936         assert_eq!(writer.bytes_written(), 106);
937     }
938 
939     #[test]
reader_test_incompatible_chain()940     fn reader_test_incompatible_chain() {
941         use DescriptorType::*;
942 
943         let memory_start_addr = GuestAddress(0x0);
944         let memory = GuestMemory::new(&[(memory_start_addr, 0x10000)]).unwrap();
945 
946         let chain = create_descriptor_chain(
947             &memory,
948             GuestAddress(0x0),
949             GuestAddress(0x100),
950             vec![(Writable, 8)],
951             0,
952         )
953         .expect("create_descriptor_chain failed");
954         let mut reader = Reader::new(memory, chain).expect("failed to create Reader");
955         assert_eq!(reader.available_bytes(), 0);
956         assert_eq!(reader.bytes_read(), 0);
957 
958         assert!(reader.read_obj::<u8>().is_err());
959 
960         assert_eq!(reader.available_bytes(), 0);
961         assert_eq!(reader.bytes_read(), 0);
962     }
963 
964     #[test]
writer_test_incompatible_chain()965     fn writer_test_incompatible_chain() {
966         use DescriptorType::*;
967 
968         let memory_start_addr = GuestAddress(0x0);
969         let memory = GuestMemory::new(&[(memory_start_addr, 0x10000)]).unwrap();
970 
971         let chain = create_descriptor_chain(
972             &memory,
973             GuestAddress(0x0),
974             GuestAddress(0x100),
975             vec![(Readable, 8)],
976             0,
977         )
978         .expect("create_descriptor_chain failed");
979         let mut writer = Writer::new(memory, chain).expect("failed to create Writer");
980         assert_eq!(writer.available_bytes(), 0);
981         assert_eq!(writer.bytes_written(), 0);
982 
983         assert!(writer.write_obj(0u8).is_err());
984 
985         assert_eq!(writer.available_bytes(), 0);
986         assert_eq!(writer.bytes_written(), 0);
987     }
988 
989     #[test]
reader_failing_io()990     fn reader_failing_io() {
991         use DescriptorType::*;
992 
993         let memory_start_addr = GuestAddress(0x0);
994         let memory = GuestMemory::new(&[(memory_start_addr, 0x10000)]).unwrap();
995 
996         let chain = create_descriptor_chain(
997             &memory,
998             GuestAddress(0x0),
999             GuestAddress(0x100),
1000             vec![(Readable, 256), (Readable, 256)],
1001             0,
1002         )
1003         .expect("create_descriptor_chain failed");
1004 
1005         let mut reader = Reader::new(memory, chain).expect("failed to create Reader");
1006 
1007         // Open a file in read-only mode so writes to it to trigger an I/O error.
1008         let device_file = if cfg!(windows) { "NUL" } else { "/dev/zero" };
1009         let mut ro_file = File::open(device_file).expect("failed to open device file");
1010 
1011         reader
1012             .read_exact_to(&mut ro_file, 512)
1013             .expect_err("successfully read more bytes than SharedMemory size");
1014 
1015         // The write above should have failed entirely, so we end up not writing any bytes at all.
1016         assert_eq!(reader.available_bytes(), 512);
1017         assert_eq!(reader.bytes_read(), 0);
1018     }
1019 
1020     #[test]
writer_failing_io()1021     fn writer_failing_io() {
1022         use DescriptorType::*;
1023 
1024         let memory_start_addr = GuestAddress(0x0);
1025         let memory = GuestMemory::new(&[(memory_start_addr, 0x10000)]).unwrap();
1026 
1027         let chain = create_descriptor_chain(
1028             &memory,
1029             GuestAddress(0x0),
1030             GuestAddress(0x100),
1031             vec![(Writable, 256), (Writable, 256)],
1032             0,
1033         )
1034         .expect("create_descriptor_chain failed");
1035 
1036         let mut writer = Writer::new(memory, chain).expect("failed to create Writer");
1037 
1038         let mut file = tempfile().unwrap();
1039 
1040         file.set_len(384).unwrap();
1041 
1042         writer
1043             .write_all_from(&mut file, 512)
1044             .expect_err("successfully wrote more bytes than in SharedMemory");
1045 
1046         assert_eq!(writer.available_bytes(), 128);
1047         assert_eq!(writer.bytes_written(), 384);
1048     }
1049 
1050     #[test]
reader_writer_shared_chain()1051     fn reader_writer_shared_chain() {
1052         use DescriptorType::*;
1053 
1054         let memory_start_addr = GuestAddress(0x0);
1055         let memory = GuestMemory::new(&[(memory_start_addr, 0x10000)]).unwrap();
1056 
1057         let chain = create_descriptor_chain(
1058             &memory,
1059             GuestAddress(0x0),
1060             GuestAddress(0x100),
1061             vec![
1062                 (Readable, 16),
1063                 (Readable, 16),
1064                 (Readable, 96),
1065                 (Writable, 64),
1066                 (Writable, 1),
1067                 (Writable, 3),
1068             ],
1069             0,
1070         )
1071         .expect("create_descriptor_chain failed");
1072         let mut reader =
1073             Reader::new(memory.clone(), chain.clone()).expect("failed to create Reader");
1074         let mut writer = Writer::new(memory, chain).expect("failed to create Writer");
1075 
1076         assert_eq!(reader.bytes_read(), 0);
1077         assert_eq!(writer.bytes_written(), 0);
1078 
1079         let mut buffer = Vec::with_capacity(200);
1080 
1081         assert_eq!(
1082             reader
1083                 .read_to_end(&mut buffer)
1084                 .expect("read should not fail here"),
1085             128
1086         );
1087 
1088         // The writable descriptors are only 68 bytes long.
1089         writer
1090             .write_all(&buffer[..68])
1091             .expect("write should not fail here");
1092 
1093         assert_eq!(reader.available_bytes(), 0);
1094         assert_eq!(reader.bytes_read(), 128);
1095         assert_eq!(writer.available_bytes(), 0);
1096         assert_eq!(writer.bytes_written(), 68);
1097     }
1098 
1099     #[test]
reader_writer_shattered_object()1100     fn reader_writer_shattered_object() {
1101         use DescriptorType::*;
1102 
1103         let memory_start_addr = GuestAddress(0x0);
1104         let memory = GuestMemory::new(&[(memory_start_addr, 0x10000)]).unwrap();
1105 
1106         let secret: Le32 = 0x12345678.into();
1107 
1108         // Create a descriptor chain with memory regions that are properly separated.
1109         let chain_writer = create_descriptor_chain(
1110             &memory,
1111             GuestAddress(0x0),
1112             GuestAddress(0x100),
1113             vec![(Writable, 1), (Writable, 1), (Writable, 1), (Writable, 1)],
1114             123,
1115         )
1116         .expect("create_descriptor_chain failed");
1117         let mut writer =
1118             Writer::new(memory.clone(), chain_writer).expect("failed to create Writer");
1119         writer
1120             .write_obj(secret)
1121             .expect("write_obj should not fail here");
1122 
1123         // Now create new descriptor chain pointing to the same memory and try to read it.
1124         let chain_reader = create_descriptor_chain(
1125             &memory,
1126             GuestAddress(0x0),
1127             GuestAddress(0x100),
1128             vec![(Readable, 1), (Readable, 1), (Readable, 1), (Readable, 1)],
1129             123,
1130         )
1131         .expect("create_descriptor_chain failed");
1132         let mut reader = Reader::new(memory, chain_reader).expect("failed to create Reader");
1133         match reader.read_obj::<Le32>() {
1134             Err(_) => panic!("read_obj should not fail here"),
1135             Ok(read_secret) => assert_eq!(read_secret, secret),
1136         }
1137     }
1138 
1139     #[test]
reader_unexpected_eof()1140     fn reader_unexpected_eof() {
1141         use DescriptorType::*;
1142 
1143         let memory_start_addr = GuestAddress(0x0);
1144         let memory = GuestMemory::new(&[(memory_start_addr, 0x10000)]).unwrap();
1145 
1146         let chain = create_descriptor_chain(
1147             &memory,
1148             GuestAddress(0x0),
1149             GuestAddress(0x100),
1150             vec![(Readable, 256), (Readable, 256)],
1151             0,
1152         )
1153         .expect("create_descriptor_chain failed");
1154 
1155         let mut reader = Reader::new(memory, chain).expect("failed to create Reader");
1156 
1157         let mut buf = vec![0; 1024];
1158 
1159         assert_eq!(
1160             reader
1161                 .read_exact(&mut buf[..])
1162                 .expect_err("read more bytes than available")
1163                 .kind(),
1164             io::ErrorKind::UnexpectedEof
1165         );
1166     }
1167 
1168     #[test]
split_border()1169     fn split_border() {
1170         use DescriptorType::*;
1171 
1172         let memory_start_addr = GuestAddress(0x0);
1173         let memory = GuestMemory::new(&[(memory_start_addr, 0x10000)]).unwrap();
1174 
1175         let chain = create_descriptor_chain(
1176             &memory,
1177             GuestAddress(0x0),
1178             GuestAddress(0x100),
1179             vec![
1180                 (Readable, 16),
1181                 (Readable, 16),
1182                 (Readable, 96),
1183                 (Writable, 64),
1184                 (Writable, 1),
1185                 (Writable, 3),
1186             ],
1187             0,
1188         )
1189         .expect("create_descriptor_chain failed");
1190         let mut reader = Reader::new(memory, chain).expect("failed to create Reader");
1191 
1192         let other = reader.split_at(32);
1193         assert_eq!(reader.available_bytes(), 32);
1194         assert_eq!(other.available_bytes(), 96);
1195     }
1196 
1197     #[test]
split_middle()1198     fn split_middle() {
1199         use DescriptorType::*;
1200 
1201         let memory_start_addr = GuestAddress(0x0);
1202         let memory = GuestMemory::new(&[(memory_start_addr, 0x10000)]).unwrap();
1203 
1204         let chain = create_descriptor_chain(
1205             &memory,
1206             GuestAddress(0x0),
1207             GuestAddress(0x100),
1208             vec![
1209                 (Readable, 16),
1210                 (Readable, 16),
1211                 (Readable, 96),
1212                 (Writable, 64),
1213                 (Writable, 1),
1214                 (Writable, 3),
1215             ],
1216             0,
1217         )
1218         .expect("create_descriptor_chain failed");
1219         let mut reader = Reader::new(memory, chain).expect("failed to create Reader");
1220 
1221         let other = reader.split_at(24);
1222         assert_eq!(reader.available_bytes(), 24);
1223         assert_eq!(other.available_bytes(), 104);
1224     }
1225 
1226     #[test]
split_end()1227     fn split_end() {
1228         use DescriptorType::*;
1229 
1230         let memory_start_addr = GuestAddress(0x0);
1231         let memory = GuestMemory::new(&[(memory_start_addr, 0x10000)]).unwrap();
1232 
1233         let chain = create_descriptor_chain(
1234             &memory,
1235             GuestAddress(0x0),
1236             GuestAddress(0x100),
1237             vec![
1238                 (Readable, 16),
1239                 (Readable, 16),
1240                 (Readable, 96),
1241                 (Writable, 64),
1242                 (Writable, 1),
1243                 (Writable, 3),
1244             ],
1245             0,
1246         )
1247         .expect("create_descriptor_chain failed");
1248         let mut reader = Reader::new(memory, chain).expect("failed to create Reader");
1249 
1250         let other = reader.split_at(128);
1251         assert_eq!(reader.available_bytes(), 128);
1252         assert_eq!(other.available_bytes(), 0);
1253     }
1254 
1255     #[test]
split_beginning()1256     fn split_beginning() {
1257         use DescriptorType::*;
1258 
1259         let memory_start_addr = GuestAddress(0x0);
1260         let memory = GuestMemory::new(&[(memory_start_addr, 0x10000)]).unwrap();
1261 
1262         let chain = create_descriptor_chain(
1263             &memory,
1264             GuestAddress(0x0),
1265             GuestAddress(0x100),
1266             vec![
1267                 (Readable, 16),
1268                 (Readable, 16),
1269                 (Readable, 96),
1270                 (Writable, 64),
1271                 (Writable, 1),
1272                 (Writable, 3),
1273             ],
1274             0,
1275         )
1276         .expect("create_descriptor_chain failed");
1277         let mut reader = Reader::new(memory, chain).expect("failed to create Reader");
1278 
1279         let other = reader.split_at(0);
1280         assert_eq!(reader.available_bytes(), 0);
1281         assert_eq!(other.available_bytes(), 128);
1282     }
1283 
1284     #[test]
split_outofbounds()1285     fn split_outofbounds() {
1286         use DescriptorType::*;
1287 
1288         let memory_start_addr = GuestAddress(0x0);
1289         let memory = GuestMemory::new(&[(memory_start_addr, 0x10000)]).unwrap();
1290 
1291         let chain = create_descriptor_chain(
1292             &memory,
1293             GuestAddress(0x0),
1294             GuestAddress(0x100),
1295             vec![
1296                 (Readable, 16),
1297                 (Readable, 16),
1298                 (Readable, 96),
1299                 (Writable, 64),
1300                 (Writable, 1),
1301                 (Writable, 3),
1302             ],
1303             0,
1304         )
1305         .expect("create_descriptor_chain failed");
1306         let mut reader = Reader::new(memory, chain).expect("failed to create Reader");
1307 
1308         let other = reader.split_at(256);
1309         assert_eq!(
1310             other.available_bytes(),
1311             0,
1312             "Reader returned from out-of-bounds split still has available bytes"
1313         );
1314     }
1315 
1316     #[test]
read_full()1317     fn read_full() {
1318         use DescriptorType::*;
1319 
1320         let memory_start_addr = GuestAddress(0x0);
1321         let memory = GuestMemory::new(&[(memory_start_addr, 0x10000)]).unwrap();
1322 
1323         let chain = create_descriptor_chain(
1324             &memory,
1325             GuestAddress(0x0),
1326             GuestAddress(0x100),
1327             vec![(Readable, 16), (Readable, 16), (Readable, 16)],
1328             0,
1329         )
1330         .expect("create_descriptor_chain failed");
1331         let mut reader = Reader::new(memory, chain).expect("failed to create Reader");
1332 
1333         let mut buf = vec![0u8; 64];
1334         assert_eq!(
1335             reader.read(&mut buf[..]).expect("failed to read to buffer"),
1336             48
1337         );
1338     }
1339 
1340     #[test]
write_full()1341     fn write_full() {
1342         use DescriptorType::*;
1343 
1344         let memory_start_addr = GuestAddress(0x0);
1345         let memory = GuestMemory::new(&[(memory_start_addr, 0x10000)]).unwrap();
1346 
1347         let chain = create_descriptor_chain(
1348             &memory,
1349             GuestAddress(0x0),
1350             GuestAddress(0x100),
1351             vec![(Writable, 16), (Writable, 16), (Writable, 16)],
1352             0,
1353         )
1354         .expect("create_descriptor_chain failed");
1355         let mut writer = Writer::new(memory, chain).expect("failed to create Writer");
1356 
1357         let buf = vec![0xdeu8; 64];
1358         assert_eq!(
1359             writer.write(&buf[..]).expect("failed to write from buffer"),
1360             48
1361         );
1362     }
1363 
1364     #[test]
consume_collect()1365     fn consume_collect() {
1366         use DescriptorType::*;
1367 
1368         let memory_start_addr = GuestAddress(0x0);
1369         let memory = GuestMemory::new(&[(memory_start_addr, 0x10000)]).unwrap();
1370         let vs: Vec<Le64> = vec![
1371             0x0101010101010101.into(),
1372             0x0202020202020202.into(),
1373             0x0303030303030303.into(),
1374         ];
1375 
1376         let write_chain = create_descriptor_chain(
1377             &memory,
1378             GuestAddress(0x0),
1379             GuestAddress(0x100),
1380             vec![(Writable, 24)],
1381             0,
1382         )
1383         .expect("create_descriptor_chain failed");
1384         let mut writer = Writer::new(memory.clone(), write_chain).expect("failed to create Writer");
1385         writer
1386             .consume(vs.clone())
1387             .expect("failed to consume() a vector");
1388 
1389         let read_chain = create_descriptor_chain(
1390             &memory,
1391             GuestAddress(0x0),
1392             GuestAddress(0x100),
1393             vec![(Readable, 24)],
1394             0,
1395         )
1396         .expect("create_descriptor_chain failed");
1397         let mut reader = Reader::new(memory, read_chain).expect("failed to create Reader");
1398         let vs_read = reader
1399             .collect::<io::Result<Vec<Le64>>, _>()
1400             .expect("failed to collect() values");
1401         assert_eq!(vs, vs_read);
1402     }
1403 
1404     #[test]
get_remaining_region_with_count()1405     fn get_remaining_region_with_count() {
1406         use DescriptorType::*;
1407 
1408         let memory_start_addr = GuestAddress(0x0);
1409         let memory = GuestMemory::new(&[(memory_start_addr, 0x10000)]).unwrap();
1410 
1411         let chain = create_descriptor_chain(
1412             &memory,
1413             GuestAddress(0x0),
1414             GuestAddress(0x100),
1415             vec![
1416                 (Readable, 16),
1417                 (Readable, 16),
1418                 (Readable, 96),
1419                 (Writable, 64),
1420                 (Writable, 1),
1421                 (Writable, 3),
1422             ],
1423             0,
1424         )
1425         .expect("create_descriptor_chain failed");
1426 
1427         let Reader {
1428             mem: _,
1429             mut regions,
1430         } = Reader::new(memory, chain).expect("failed to create Reader");
1431 
1432         let drain = regions
1433             .get_remaining_regions_with_count(::std::usize::MAX)
1434             .iter()
1435             .fold(0usize, |total, region| total + region.len);
1436         assert_eq!(drain, 128);
1437 
1438         let exact = regions
1439             .get_remaining_regions_with_count(32)
1440             .iter()
1441             .fold(0usize, |total, region| total + region.len);
1442         assert!(exact > 0);
1443         assert!(exact <= 32);
1444 
1445         let split = regions
1446             .get_remaining_regions_with_count(24)
1447             .iter()
1448             .fold(0usize, |total, region| total + region.len);
1449         assert!(split > 0);
1450         assert!(split <= 24);
1451 
1452         regions.consume(64);
1453 
1454         let first = regions
1455             .get_remaining_regions_with_count(8)
1456             .iter()
1457             .fold(0usize, |total, region| total + region.len);
1458         assert!(first > 0);
1459         assert!(first <= 8);
1460     }
1461 
1462     #[test]
get_remaining_with_count()1463     fn get_remaining_with_count() {
1464         use DescriptorType::*;
1465 
1466         let memory_start_addr = GuestAddress(0x0);
1467         let memory = GuestMemory::new(&[(memory_start_addr, 0x10000)]).unwrap();
1468 
1469         let chain = create_descriptor_chain(
1470             &memory,
1471             GuestAddress(0x0),
1472             GuestAddress(0x100),
1473             vec![
1474                 (Readable, 16),
1475                 (Readable, 16),
1476                 (Readable, 96),
1477                 (Writable, 64),
1478                 (Writable, 1),
1479                 (Writable, 3),
1480             ],
1481             0,
1482         )
1483         .expect("create_descriptor_chain failed");
1484         let Reader {
1485             mem: _,
1486             mut regions,
1487         } = Reader::new(memory.clone(), chain).expect("failed to create Reader");
1488 
1489         let drain = regions
1490             .get_remaining_with_count(&memory, ::std::usize::MAX)
1491             .iter()
1492             .fold(0usize, |total, iov| total + iov.size());
1493         assert_eq!(drain, 128);
1494 
1495         let exact = regions
1496             .get_remaining_with_count(&memory, 32)
1497             .iter()
1498             .fold(0usize, |total, iov| total + iov.size());
1499         assert!(exact > 0);
1500         assert!(exact <= 32);
1501 
1502         let split = regions
1503             .get_remaining_with_count(&memory, 24)
1504             .iter()
1505             .fold(0usize, |total, iov| total + iov.size());
1506         assert!(split > 0);
1507         assert!(split <= 24);
1508 
1509         regions.consume(64);
1510 
1511         let first = regions
1512             .get_remaining_with_count(&memory, 8)
1513             .iter()
1514             .fold(0usize, |total, iov| total + iov.size());
1515         assert!(first > 0);
1516         assert!(first <= 8);
1517     }
1518 }
1519