• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 // Copyright 2019 The ChromiumOS Authors
2 // Use of this source code is governed by a BSD-style license that can be
3 // found in the LICENSE file.
4 
5 use std::cmp;
6 use std::io;
7 use std::io::Write;
8 use std::iter::FromIterator;
9 use std::marker::PhantomData;
10 use std::mem::size_of;
11 use std::mem::MaybeUninit;
12 use std::ptr::copy_nonoverlapping;
13 use std::sync::Arc;
14 
15 use anyhow::Context;
16 use base::FileReadWriteAtVolatile;
17 use base::FileReadWriteVolatile;
18 use base::VolatileSlice;
19 use cros_async::MemRegion;
20 use cros_async::MemRegionIter;
21 use data_model::Le16;
22 use data_model::Le32;
23 use data_model::Le64;
24 use disk::AsyncDisk;
25 use smallvec::SmallVec;
26 use vm_memory::GuestAddress;
27 use vm_memory::GuestMemory;
28 use zerocopy::FromBytes;
29 use zerocopy::Immutable;
30 use zerocopy::IntoBytes;
31 use zerocopy::KnownLayout;
32 
33 use super::DescriptorChain;
34 use crate::virtio::SplitDescriptorChain;
35 
36 struct DescriptorChainRegions {
37     regions: SmallVec<[MemRegion; 2]>,
38 
39     // Index of the current region in `regions`.
40     current_region_index: usize,
41 
42     // Number of bytes consumed in the current region.
43     current_region_offset: usize,
44 
45     // Total bytes consumed in the entire descriptor chain.
46     bytes_consumed: usize,
47 }
48 
49 impl DescriptorChainRegions {
new(regions: SmallVec<[MemRegion; 2]>) -> Self50     fn new(regions: SmallVec<[MemRegion; 2]>) -> Self {
51         DescriptorChainRegions {
52             regions,
53             current_region_index: 0,
54             current_region_offset: 0,
55             bytes_consumed: 0,
56         }
57     }
58 
available_bytes(&self) -> usize59     fn available_bytes(&self) -> usize {
60         // This is guaranteed not to overflow because the total length of the chain is checked
61         // during all creations of `DescriptorChain` (see `DescriptorChain::new()`).
62         self.get_remaining_regions()
63             .fold(0usize, |count, region| count + region.len)
64     }
65 
bytes_consumed(&self) -> usize66     fn bytes_consumed(&self) -> usize {
67         self.bytes_consumed
68     }
69 
70     /// Returns all the remaining buffers in the `DescriptorChain`. Calling this function does not
71     /// consume any bytes from the `DescriptorChain`. Instead callers should use the `consume`
72     /// method to advance the `DescriptorChain`. Multiple calls to `get` with no intervening calls
73     /// to `consume` will return the same data.
get_remaining_regions(&self) -> MemRegionIter74     fn get_remaining_regions(&self) -> MemRegionIter {
75         MemRegionIter::new(&self.regions[self.current_region_index..])
76             .skip_bytes(self.current_region_offset)
77     }
78 
79     /// Like `get_remaining_regions` but guarantees that the combined length of all the returned
80     /// iovecs is not greater than `count`. The combined length of the returned iovecs may be less
81     /// than `count` but will always be greater than 0 as long as there is still space left in the
82     /// `DescriptorChain`.
get_remaining_regions_with_count(&self, count: usize) -> MemRegionIter83     fn get_remaining_regions_with_count(&self, count: usize) -> MemRegionIter {
84         MemRegionIter::new(&self.regions[self.current_region_index..])
85             .skip_bytes(self.current_region_offset)
86             .take_bytes(count)
87     }
88 
89     /// Returns all the remaining buffers in the `DescriptorChain` as `VolatileSlice`s of the given
90     /// `GuestMemory`. Calling this function does not consume any bytes from the `DescriptorChain`.
91     /// Instead callers should use the `consume` method to advance the `DescriptorChain`. Multiple
92     /// calls to `get` with no intervening calls to `consume` will return the same data.
get_remaining<'mem>(&self, mem: &'mem GuestMemory) -> SmallVec<[VolatileSlice<'mem>; 16]>93     fn get_remaining<'mem>(&self, mem: &'mem GuestMemory) -> SmallVec<[VolatileSlice<'mem>; 16]> {
94         self.get_remaining_regions()
95             .filter_map(|region| {
96                 mem.get_slice_at_addr(GuestAddress(region.offset), region.len)
97                     .ok()
98             })
99             .collect()
100     }
101 
102     /// Like 'get_remaining_regions_with_count' except convert the offsets to volatile slices in
103     /// the 'GuestMemory' given by 'mem'.
get_remaining_with_count<'mem>( &self, mem: &'mem GuestMemory, count: usize, ) -> SmallVec<[VolatileSlice<'mem>; 16]>104     fn get_remaining_with_count<'mem>(
105         &self,
106         mem: &'mem GuestMemory,
107         count: usize,
108     ) -> SmallVec<[VolatileSlice<'mem>; 16]> {
109         self.get_remaining_regions_with_count(count)
110             .filter_map(|region| {
111                 mem.get_slice_at_addr(GuestAddress(region.offset), region.len)
112                     .ok()
113             })
114             .collect()
115     }
116 
117     /// Consumes `count` bytes from the `DescriptorChain`. If `count` is larger than
118     /// `self.available_bytes()` then all remaining bytes in the `DescriptorChain` will be consumed.
consume(&mut self, mut count: usize)119     fn consume(&mut self, mut count: usize) {
120         while let Some(region) = self.regions.get(self.current_region_index) {
121             let region_remaining = region.len - self.current_region_offset;
122             if count < region_remaining {
123                 // The remaining count to consume is less than the remaining un-consumed length of
124                 // the current region. Adjust the region offset without advancing to the next region
125                 // and stop.
126                 self.current_region_offset += count;
127                 self.bytes_consumed += count;
128                 return;
129             }
130 
131             // The current region has been exhausted. Advance to the next region.
132             self.current_region_index += 1;
133             self.current_region_offset = 0;
134 
135             self.bytes_consumed += region_remaining;
136             count -= region_remaining;
137         }
138     }
139 
split_at(&mut self, offset: usize) -> DescriptorChainRegions140     fn split_at(&mut self, offset: usize) -> DescriptorChainRegions {
141         let mut other = DescriptorChainRegions {
142             regions: self.regions.clone(),
143             current_region_index: self.current_region_index,
144             current_region_offset: self.current_region_offset,
145             bytes_consumed: self.bytes_consumed,
146         };
147         other.consume(offset);
148         other.bytes_consumed = 0;
149 
150         let mut rem = offset;
151         let mut end = self.current_region_index;
152         for region in &mut self.regions[self.current_region_index..] {
153             if rem <= region.len {
154                 region.len = rem;
155                 break;
156             }
157 
158             end += 1;
159             rem -= region.len;
160         }
161 
162         self.regions.truncate(end + 1);
163 
164         other
165     }
166 }
167 
168 /// Provides high-level interface over the sequence of memory regions
169 /// defined by readable descriptors in the descriptor chain.
170 ///
171 /// Note that virtio spec requires driver to place any device-writable
172 /// descriptors after any device-readable descriptors (2.6.4.2 in Virtio Spec v1.1).
173 /// Reader will skip iterating over descriptor chain when first writable
174 /// descriptor is encountered.
175 pub struct Reader {
176     mem: GuestMemory,
177     regions: DescriptorChainRegions,
178 }
179 
180 /// An iterator over `FromBytes` objects on readable descriptors in the descriptor chain.
181 pub struct ReaderIterator<'a, T: FromBytes> {
182     reader: &'a mut Reader,
183     phantom: PhantomData<T>,
184 }
185 
186 impl<T: FromBytes> Iterator for ReaderIterator<'_, T> {
187     type Item = io::Result<T>;
188 
next(&mut self) -> Option<io::Result<T>>189     fn next(&mut self) -> Option<io::Result<T>> {
190         if self.reader.available_bytes() == 0 {
191             None
192         } else {
193             Some(self.reader.read_obj())
194         }
195     }
196 }
197 
198 impl Reader {
199     /// Construct a new Reader wrapper over `readable_regions`.
new_from_regions( mem: &GuestMemory, readable_regions: SmallVec<[MemRegion; 2]>, ) -> Reader200     pub fn new_from_regions(
201         mem: &GuestMemory,
202         readable_regions: SmallVec<[MemRegion; 2]>,
203     ) -> Reader {
204         Reader {
205             mem: mem.clone(),
206             regions: DescriptorChainRegions::new(readable_regions),
207         }
208     }
209 
210     /// Reads an object from the descriptor chain buffer without consuming it.
peek_obj<T: FromBytes>(&self) -> io::Result<T>211     pub fn peek_obj<T: FromBytes>(&self) -> io::Result<T> {
212         let mut obj = MaybeUninit::uninit();
213 
214         // SAFETY: We pass a valid pointer and size of `obj`.
215         let copied = unsafe {
216             copy_regions_to_mut_ptr(
217                 &self.mem,
218                 self.get_remaining_regions(),
219                 obj.as_mut_ptr() as *mut u8,
220                 size_of::<T>(),
221             )?
222         };
223         if copied != size_of::<T>() {
224             return Err(io::Error::from(io::ErrorKind::UnexpectedEof));
225         }
226 
227         // SAFETY: `FromBytes` guarantees any set of initialized bytes is a valid value for `T`, and
228         // we initialized all bytes in `obj` in the copy above.
229         Ok(unsafe { obj.assume_init() })
230     }
231 
232     /// Reads and consumes an object from the descriptor chain buffer.
read_obj<T: FromBytes>(&mut self) -> io::Result<T>233     pub fn read_obj<T: FromBytes>(&mut self) -> io::Result<T> {
234         let obj = self.peek_obj::<T>()?;
235         self.consume(size_of::<T>());
236         Ok(obj)
237     }
238 
239     /// Reads objects by consuming all the remaining data in the descriptor chain buffer and returns
240     /// them as a collection. Returns an error if the size of the remaining data is indivisible by
241     /// the size of an object of type `T`.
collect<C: FromIterator<io::Result<T>>, T: FromBytes>(&mut self) -> C242     pub fn collect<C: FromIterator<io::Result<T>>, T: FromBytes>(&mut self) -> C {
243         self.iter().collect()
244     }
245 
246     /// Creates an iterator for sequentially reading `FromBytes` objects from the `Reader`.
247     /// Unlike `collect`, this doesn't consume all the remaining data in the `Reader` and
248     /// doesn't require the objects to be stored in a separate collection.
iter<T: FromBytes>(&mut self) -> ReaderIterator<T>249     pub fn iter<T: FromBytes>(&mut self) -> ReaderIterator<T> {
250         ReaderIterator {
251             reader: self,
252             phantom: PhantomData,
253         }
254     }
255 
256     /// Reads data into a volatile slice up to the minimum of the slice's length or the number of
257     /// bytes remaining. Returns the number of bytes read.
read_to_volatile_slice(&mut self, slice: VolatileSlice) -> usize258     pub fn read_to_volatile_slice(&mut self, slice: VolatileSlice) -> usize {
259         let mut read = 0usize;
260         let mut dst = slice;
261         for src in self.get_remaining() {
262             src.copy_to_volatile_slice(dst);
263             let copied = std::cmp::min(src.size(), dst.size());
264             read += copied;
265             dst = match dst.offset(copied) {
266                 Ok(v) => v,
267                 Err(_) => break, // The slice is fully consumed
268             };
269         }
270         self.regions.consume(read);
271         read
272     }
273 
274     /// Reads data from the descriptor chain buffer and passes the `VolatileSlice`s to the callback
275     /// `cb`.
read_to_cb<C: FnOnce(&[VolatileSlice]) -> usize>( &mut self, cb: C, count: usize, ) -> usize276     pub fn read_to_cb<C: FnOnce(&[VolatileSlice]) -> usize>(
277         &mut self,
278         cb: C,
279         count: usize,
280     ) -> usize {
281         let iovs = self.regions.get_remaining_with_count(&self.mem, count);
282         let written = cb(&iovs[..]);
283         self.regions.consume(written);
284         written
285     }
286 
287     /// Reads data from the descriptor chain buffer into a writable object.
288     /// Returns the number of bytes read from the descriptor chain buffer.
289     /// The number of bytes read can be less than `count` if there isn't
290     /// enough data in the descriptor chain buffer.
read_to<F: FileReadWriteVolatile>( &mut self, mut dst: F, count: usize, ) -> io::Result<usize>291     pub fn read_to<F: FileReadWriteVolatile>(
292         &mut self,
293         mut dst: F,
294         count: usize,
295     ) -> io::Result<usize> {
296         let iovs = self.regions.get_remaining_with_count(&self.mem, count);
297         let written = dst.write_vectored_volatile(&iovs[..])?;
298         self.regions.consume(written);
299         Ok(written)
300     }
301 
302     /// Reads data from the descriptor chain buffer into a File at offset `off`.
303     /// Returns the number of bytes read from the descriptor chain buffer.
304     /// The number of bytes read can be less than `count` if there isn't
305     /// enough data in the descriptor chain buffer.
read_to_at<F: FileReadWriteAtVolatile>( &mut self, dst: &F, count: usize, off: u64, ) -> io::Result<usize>306     pub fn read_to_at<F: FileReadWriteAtVolatile>(
307         &mut self,
308         dst: &F,
309         count: usize,
310         off: u64,
311     ) -> io::Result<usize> {
312         let iovs = self.regions.get_remaining_with_count(&self.mem, count);
313         let written = dst.write_vectored_at_volatile(&iovs[..], off)?;
314         self.regions.consume(written);
315         Ok(written)
316     }
317 
318     /// Reads data from the descriptor chain similar to 'read_to' except reading 'count' or
319     /// returning an error if 'count' bytes can't be read.
read_exact_to<F: FileReadWriteVolatile>( &mut self, mut dst: F, mut count: usize, ) -> io::Result<()>320     pub fn read_exact_to<F: FileReadWriteVolatile>(
321         &mut self,
322         mut dst: F,
323         mut count: usize,
324     ) -> io::Result<()> {
325         while count > 0 {
326             match self.read_to(&mut dst, count) {
327                 Ok(0) => {
328                     return Err(io::Error::new(
329                         io::ErrorKind::UnexpectedEof,
330                         "failed to fill whole buffer",
331                     ))
332                 }
333                 Ok(n) => count -= n,
334                 Err(ref e) if e.kind() == io::ErrorKind::Interrupted => {}
335                 Err(e) => return Err(e),
336             }
337         }
338 
339         Ok(())
340     }
341 
342     /// Reads data from the descriptor chain similar to 'read_to_at' except reading 'count' or
343     /// returning an error if 'count' bytes can't be read.
read_exact_to_at<F: FileReadWriteAtVolatile>( &mut self, dst: &F, mut count: usize, mut off: u64, ) -> io::Result<()>344     pub fn read_exact_to_at<F: FileReadWriteAtVolatile>(
345         &mut self,
346         dst: &F,
347         mut count: usize,
348         mut off: u64,
349     ) -> io::Result<()> {
350         while count > 0 {
351             match self.read_to_at(dst, count, off) {
352                 Ok(0) => {
353                     return Err(io::Error::new(
354                         io::ErrorKind::UnexpectedEof,
355                         "failed to fill whole buffer",
356                     ))
357                 }
358                 Ok(n) => {
359                     count -= n;
360                     off += n as u64;
361                 }
362                 Err(ref e) if e.kind() == io::ErrorKind::Interrupted => {}
363                 Err(e) => return Err(e),
364             }
365         }
366 
367         Ok(())
368     }
369 
370     /// Reads data from the descriptor chain buffer into an `AsyncDisk` at offset `off`.
371     /// Returns the number of bytes read from the descriptor chain buffer.
372     /// The number of bytes read can be less than `count` if there isn't
373     /// enough data in the descriptor chain buffer.
read_to_at_fut<F: AsyncDisk + ?Sized>( &mut self, dst: &F, count: usize, off: u64, ) -> disk::Result<usize>374     pub async fn read_to_at_fut<F: AsyncDisk + ?Sized>(
375         &mut self,
376         dst: &F,
377         count: usize,
378         off: u64,
379     ) -> disk::Result<usize> {
380         let written = dst
381             .write_from_mem(
382                 off,
383                 Arc::new(self.mem.clone()),
384                 self.regions.get_remaining_regions_with_count(count),
385             )
386             .await?;
387         self.regions.consume(written);
388         Ok(written)
389     }
390 
391     /// Reads exactly `count` bytes from the chain to the disk asynchronously or returns an error if
392     /// not enough data can be read.
read_exact_to_at_fut<F: AsyncDisk + ?Sized>( &mut self, dst: &F, mut count: usize, mut off: u64, ) -> disk::Result<()>393     pub async fn read_exact_to_at_fut<F: AsyncDisk + ?Sized>(
394         &mut self,
395         dst: &F,
396         mut count: usize,
397         mut off: u64,
398     ) -> disk::Result<()> {
399         while count > 0 {
400             let nread = self.read_to_at_fut(dst, count, off).await?;
401             if nread == 0 {
402                 return Err(disk::Error::ReadingData(io::Error::new(
403                     io::ErrorKind::UnexpectedEof,
404                     "failed to write whole buffer",
405                 )));
406             }
407             count -= nread;
408             off += nread as u64;
409         }
410 
411         Ok(())
412     }
413 
414     /// Returns number of bytes available for reading.  May return an error if the combined
415     /// lengths of all the buffers in the DescriptorChain would cause an integer overflow.
available_bytes(&self) -> usize416     pub fn available_bytes(&self) -> usize {
417         self.regions.available_bytes()
418     }
419 
420     /// Returns number of bytes already read from the descriptor chain buffer.
bytes_read(&self) -> usize421     pub fn bytes_read(&self) -> usize {
422         self.regions.bytes_consumed()
423     }
424 
get_remaining_regions(&self) -> MemRegionIter425     pub fn get_remaining_regions(&self) -> MemRegionIter {
426         self.regions.get_remaining_regions()
427     }
428 
429     /// Returns a `&[VolatileSlice]` that represents all the remaining data in this `Reader`.
430     /// Calling this method does not actually consume any data from the `Reader` and callers should
431     /// call `consume` to advance the `Reader`.
get_remaining(&self) -> SmallVec<[VolatileSlice; 16]>432     pub fn get_remaining(&self) -> SmallVec<[VolatileSlice; 16]> {
433         self.regions.get_remaining(&self.mem)
434     }
435 
436     /// Consumes `amt` bytes from the underlying descriptor chain. If `amt` is larger than the
437     /// remaining data left in this `Reader`, then all remaining data will be consumed.
consume(&mut self, amt: usize)438     pub fn consume(&mut self, amt: usize) {
439         self.regions.consume(amt)
440     }
441 
442     /// Splits this `Reader` into two at the given offset in the `DescriptorChain` buffer. After the
443     /// split, `self` will be able to read up to `offset` bytes while the returned `Reader` can read
444     /// up to `available_bytes() - offset` bytes. If `offset > self.available_bytes()`, then the
445     /// returned `Reader` will not be able to read any bytes.
split_at(&mut self, offset: usize) -> Reader446     pub fn split_at(&mut self, offset: usize) -> Reader {
447         Reader {
448             mem: self.mem.clone(),
449             regions: self.regions.split_at(offset),
450         }
451     }
452 }
453 
454 /// Copy up to `size` bytes from `src` into `dst`.
455 ///
456 /// Returns the total number of bytes copied.
457 ///
458 /// # Safety
459 ///
460 /// The caller must ensure that it is safe to write `size` bytes of data into `dst`.
461 ///
462 /// After the function returns, it is only safe to assume that the number of bytes indicated by the
463 /// return value (which may be less than the requested `size`) have been initialized. Bytes beyond
464 /// that point are not initialized by this function.
copy_regions_to_mut_ptr( mem: &GuestMemory, src: MemRegionIter, dst: *mut u8, size: usize, ) -> io::Result<usize>465 unsafe fn copy_regions_to_mut_ptr(
466     mem: &GuestMemory,
467     src: MemRegionIter,
468     dst: *mut u8,
469     size: usize,
470 ) -> io::Result<usize> {
471     let mut copied = 0;
472     for src_region in src {
473         if copied >= size {
474             break;
475         }
476 
477         let remaining = size - copied;
478         let count = cmp::min(remaining, src_region.len);
479 
480         let vslice = mem
481             .get_slice_at_addr(GuestAddress(src_region.offset), count)
482             .map_err(|_e| io::Error::from(io::ErrorKind::InvalidData))?;
483 
484         // SAFETY: `get_slice_at_addr()` verified that the region points to valid memory, and
485         // the `count` calculation ensures we will write at most `size` bytes into `dst`.
486         unsafe {
487             copy_nonoverlapping(vslice.as_ptr(), dst.add(copied), count);
488         }
489 
490         copied += count;
491     }
492 
493     Ok(copied)
494 }
495 
496 impl io::Read for Reader {
read(&mut self, buf: &mut [u8]) -> io::Result<usize>497     fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
498         // SAFETY: We pass a valid pointer and size combination derived from `buf`.
499         let total = unsafe {
500             copy_regions_to_mut_ptr(
501                 &self.mem,
502                 self.regions.get_remaining_regions(),
503                 buf.as_mut_ptr(),
504                 buf.len(),
505             )?
506         };
507         self.regions.consume(total);
508         Ok(total)
509     }
510 }
511 
512 /// Provides high-level interface over the sequence of memory regions
513 /// defined by writable descriptors in the descriptor chain.
514 ///
515 /// Note that virtio spec requires driver to place any device-writable
516 /// descriptors after any device-readable descriptors (2.6.4.2 in Virtio Spec v1.1).
517 /// Writer will start iterating the descriptors from the first writable one and will
518 /// assume that all following descriptors are writable.
519 pub struct Writer {
520     mem: GuestMemory,
521     regions: DescriptorChainRegions,
522 }
523 
524 impl Writer {
525     /// Construct a new Writer wrapper over `writable_regions`.
new_from_regions( mem: &GuestMemory, writable_regions: SmallVec<[MemRegion; 2]>, ) -> Writer526     pub fn new_from_regions(
527         mem: &GuestMemory,
528         writable_regions: SmallVec<[MemRegion; 2]>,
529     ) -> Writer {
530         Writer {
531             mem: mem.clone(),
532             regions: DescriptorChainRegions::new(writable_regions),
533         }
534     }
535 
536     /// Writes an object to the descriptor chain buffer.
write_obj<T: Immutable + IntoBytes>(&mut self, val: T) -> io::Result<()>537     pub fn write_obj<T: Immutable + IntoBytes>(&mut self, val: T) -> io::Result<()> {
538         self.write_all(val.as_bytes())
539     }
540 
541     /// Writes all objects produced by `iter` into the descriptor chain buffer. Unlike `consume`,
542     /// this doesn't require the values to be stored in an intermediate collection first. It also
543     /// allows callers to choose which elements in a collection to write, for example by using the
544     /// `filter` or `take` methods of the `Iterator` trait.
write_iter<T: Immutable + IntoBytes, I: Iterator<Item = T>>( &mut self, mut iter: I, ) -> io::Result<()>545     pub fn write_iter<T: Immutable + IntoBytes, I: Iterator<Item = T>>(
546         &mut self,
547         mut iter: I,
548     ) -> io::Result<()> {
549         iter.try_for_each(|v| self.write_obj(v))
550     }
551 
552     /// Writes a collection of objects into the descriptor chain buffer.
consume<T: Immutable + IntoBytes, C: IntoIterator<Item = T>>( &mut self, vals: C, ) -> io::Result<()>553     pub fn consume<T: Immutable + IntoBytes, C: IntoIterator<Item = T>>(
554         &mut self,
555         vals: C,
556     ) -> io::Result<()> {
557         self.write_iter(vals.into_iter())
558     }
559 
560     /// Returns number of bytes available for writing.  May return an error if the combined
561     /// lengths of all the buffers in the DescriptorChain would cause an overflow.
available_bytes(&self) -> usize562     pub fn available_bytes(&self) -> usize {
563         self.regions.available_bytes()
564     }
565 
566     /// Reads data into a volatile slice up to the minimum of the slice's length or the number of
567     /// bytes remaining. Returns the number of bytes read.
write_from_volatile_slice(&mut self, slice: VolatileSlice) -> usize568     pub fn write_from_volatile_slice(&mut self, slice: VolatileSlice) -> usize {
569         let mut written = 0usize;
570         let mut src = slice;
571         for dst in self.get_remaining() {
572             src.copy_to_volatile_slice(dst);
573             let copied = std::cmp::min(src.size(), dst.size());
574             written += copied;
575             src = match src.offset(copied) {
576                 Ok(v) => v,
577                 Err(_) => break, // The slice is fully consumed
578             };
579         }
580         self.regions.consume(written);
581         written
582     }
583 
584     /// Writes data to the descriptor chain buffer from a readable object.
585     /// Returns the number of bytes written to the descriptor chain buffer.
586     /// The number of bytes written can be less than `count` if
587     /// there isn't enough data in the descriptor chain buffer.
write_from<F: FileReadWriteVolatile>( &mut self, mut src: F, count: usize, ) -> io::Result<usize>588     pub fn write_from<F: FileReadWriteVolatile>(
589         &mut self,
590         mut src: F,
591         count: usize,
592     ) -> io::Result<usize> {
593         let iovs = self.regions.get_remaining_with_count(&self.mem, count);
594         let read = src.read_vectored_volatile(&iovs[..])?;
595         self.regions.consume(read);
596         Ok(read)
597     }
598 
599     /// Writes data to the descriptor chain buffer from a File at offset `off`.
600     /// Returns the number of bytes written to the descriptor chain buffer.
601     /// The number of bytes written can be less than `count` if
602     /// there isn't enough data in the descriptor chain buffer.
write_from_at<F: FileReadWriteAtVolatile>( &mut self, src: &F, count: usize, off: u64, ) -> io::Result<usize>603     pub fn write_from_at<F: FileReadWriteAtVolatile>(
604         &mut self,
605         src: &F,
606         count: usize,
607         off: u64,
608     ) -> io::Result<usize> {
609         let iovs = self.regions.get_remaining_with_count(&self.mem, count);
610         let read = src.read_vectored_at_volatile(&iovs[..], off)?;
611         self.regions.consume(read);
612         Ok(read)
613     }
614 
write_all_from<F: FileReadWriteVolatile>( &mut self, mut src: F, mut count: usize, ) -> io::Result<()>615     pub fn write_all_from<F: FileReadWriteVolatile>(
616         &mut self,
617         mut src: F,
618         mut count: usize,
619     ) -> io::Result<()> {
620         while count > 0 {
621             match self.write_from(&mut src, count) {
622                 Ok(0) => {
623                     return Err(io::Error::new(
624                         io::ErrorKind::WriteZero,
625                         "failed to write whole buffer",
626                     ))
627                 }
628                 Ok(n) => count -= n,
629                 Err(ref e) if e.kind() == io::ErrorKind::Interrupted => {}
630                 Err(e) => return Err(e),
631             }
632         }
633 
634         Ok(())
635     }
636 
write_all_from_at<F: FileReadWriteAtVolatile>( &mut self, src: &F, mut count: usize, mut off: u64, ) -> io::Result<()>637     pub fn write_all_from_at<F: FileReadWriteAtVolatile>(
638         &mut self,
639         src: &F,
640         mut count: usize,
641         mut off: u64,
642     ) -> io::Result<()> {
643         while count > 0 {
644             match self.write_from_at(src, count, off) {
645                 Ok(0) => {
646                     return Err(io::Error::new(
647                         io::ErrorKind::WriteZero,
648                         "failed to write whole buffer",
649                     ))
650                 }
651                 Ok(n) => {
652                     count -= n;
653                     off += n as u64;
654                 }
655                 Err(ref e) if e.kind() == io::ErrorKind::Interrupted => {}
656                 Err(e) => return Err(e),
657             }
658         }
659         Ok(())
660     }
661     /// Writes data to the descriptor chain buffer from an `AsyncDisk` at offset `off`.
662     /// Returns the number of bytes written to the descriptor chain buffer.
663     /// The number of bytes written can be less than `count` if
664     /// there isn't enough data in the descriptor chain buffer.
write_from_at_fut<F: AsyncDisk + ?Sized>( &mut self, src: &F, count: usize, off: u64, ) -> disk::Result<usize>665     pub async fn write_from_at_fut<F: AsyncDisk + ?Sized>(
666         &mut self,
667         src: &F,
668         count: usize,
669         off: u64,
670     ) -> disk::Result<usize> {
671         let read = src
672             .read_to_mem(
673                 off,
674                 Arc::new(self.mem.clone()),
675                 self.regions.get_remaining_regions_with_count(count),
676             )
677             .await?;
678         self.regions.consume(read);
679         Ok(read)
680     }
681 
write_all_from_at_fut<F: AsyncDisk + ?Sized>( &mut self, src: &F, mut count: usize, mut off: u64, ) -> disk::Result<()>682     pub async fn write_all_from_at_fut<F: AsyncDisk + ?Sized>(
683         &mut self,
684         src: &F,
685         mut count: usize,
686         mut off: u64,
687     ) -> disk::Result<()> {
688         while count > 0 {
689             let nwritten = self.write_from_at_fut(src, count, off).await?;
690             if nwritten == 0 {
691                 return Err(disk::Error::WritingData(io::Error::new(
692                     io::ErrorKind::UnexpectedEof,
693                     "failed to write whole buffer",
694                 )));
695             }
696             count -= nwritten;
697             off += nwritten as u64;
698         }
699         Ok(())
700     }
701 
702     /// Returns number of bytes already written to the descriptor chain buffer.
bytes_written(&self) -> usize703     pub fn bytes_written(&self) -> usize {
704         self.regions.bytes_consumed()
705     }
706 
get_remaining_regions(&self) -> MemRegionIter707     pub fn get_remaining_regions(&self) -> MemRegionIter {
708         self.regions.get_remaining_regions()
709     }
710 
711     /// Returns a `&[VolatileSlice]` that represents all the remaining data in this `Writer`.
712     /// Calling this method does not actually advance the current position of the `Writer` in the
713     /// buffer and callers should call `consume_bytes` to advance the `Writer`. Not calling
714     /// `consume_bytes` with the amount of data copied into the returned `VolatileSlice`s will
715     /// result in that that data being overwritten the next time data is written into the `Writer`.
get_remaining(&self) -> SmallVec<[VolatileSlice; 16]>716     pub fn get_remaining(&self) -> SmallVec<[VolatileSlice; 16]> {
717         self.regions.get_remaining(&self.mem)
718     }
719 
720     /// Consumes `amt` bytes from the underlying descriptor chain. If `amt` is larger than the
721     /// remaining data left in this `Reader`, then all remaining data will be consumed.
consume_bytes(&mut self, amt: usize)722     pub fn consume_bytes(&mut self, amt: usize) {
723         self.regions.consume(amt)
724     }
725 
726     /// Splits this `Writer` into two at the given offset in the `DescriptorChain` buffer. After the
727     /// split, `self` will be able to write up to `offset` bytes while the returned `Writer` can
728     /// write up to `available_bytes() - offset` bytes. If `offset > self.available_bytes()`, then
729     /// the returned `Writer` will not be able to write any bytes.
split_at(&mut self, offset: usize) -> Writer730     pub fn split_at(&mut self, offset: usize) -> Writer {
731         Writer {
732             mem: self.mem.clone(),
733             regions: self.regions.split_at(offset),
734         }
735     }
736 }
737 
738 impl io::Write for Writer {
write(&mut self, buf: &[u8]) -> io::Result<usize>739     fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
740         let mut rem = buf;
741         let mut total = 0;
742         for b in self.regions.get_remaining(&self.mem) {
743             if rem.is_empty() {
744                 break;
745             }
746 
747             let count = cmp::min(rem.len(), b.size());
748             // SAFETY:
749             // Safe because we have already verified that `vs` points to valid memory.
750             unsafe {
751                 copy_nonoverlapping(rem.as_ptr(), b.as_mut_ptr(), count);
752             }
753             rem = &rem[count..];
754             total += count;
755         }
756 
757         self.regions.consume(total);
758         Ok(total)
759     }
760 
flush(&mut self) -> io::Result<()>761     fn flush(&mut self) -> io::Result<()> {
762         // Nothing to flush since the writes go straight into the buffer.
763         Ok(())
764     }
765 }
766 
/// Descriptor flag set on every descriptor that is followed by another one (via its `next`
/// field) in a chain built by `create_descriptor_chain`.
const VIRTQ_DESC_F_NEXT: u16 = 1 << 0;
/// Descriptor flag marking a buffer of type `DescriptorType::Writable`.
const VIRTQ_DESC_F_WRITE: u16 = 1 << 1;
769 
/// Kind of buffer a descriptor describes when building a chain with `create_descriptor_chain`.
///
/// `Debug` is derived so the type can appear in assertion and error messages.
#[derive(Copy, Clone, Debug, PartialEq, Eq)]
pub enum DescriptorType {
    /// A buffer the device reads from (`VIRTQ_DESC_F_WRITE` is not set).
    Readable,
    /// A buffer the device writes to (`VIRTQ_DESC_F_WRITE` is set).
    Writable,
}
775 
/// One descriptor-table entry as written into guest memory by `create_descriptor_chain`.
///
/// `#[repr(C)]` plus the zerocopy derives let it be written as raw bytes. Do not reorder the
/// fields: the byte layout depends on their order.
#[derive(Copy, Clone, Debug, FromBytes, Immutable, IntoBytes, KnownLayout)]
#[repr(C)]
struct virtq_desc {
    // Guest address of the buffer described by this entry.
    addr: Le64,
    // Length of the buffer in bytes.
    len: Le32,
    // Combination of the `VIRTQ_DESC_F_*` flags.
    flags: Le16,
    // Index of the following descriptor; only followed when `VIRTQ_DESC_F_NEXT` is set.
    next: Le16,
}
784 
785 /// Test utility function to create a descriptor chain in guest memory.
create_descriptor_chain( memory: &GuestMemory, descriptor_array_addr: GuestAddress, mut buffers_start_addr: GuestAddress, descriptors: Vec<(DescriptorType, u32)>, spaces_between_regions: u32, ) -> anyhow::Result<DescriptorChain>786 pub fn create_descriptor_chain(
787     memory: &GuestMemory,
788     descriptor_array_addr: GuestAddress,
789     mut buffers_start_addr: GuestAddress,
790     descriptors: Vec<(DescriptorType, u32)>,
791     spaces_between_regions: u32,
792 ) -> anyhow::Result<DescriptorChain> {
793     let descriptors_len = descriptors.len();
794     for (index, (type_, size)) in descriptors.into_iter().enumerate() {
795         let mut flags = 0;
796         if let DescriptorType::Writable = type_ {
797             flags |= VIRTQ_DESC_F_WRITE;
798         }
799         if index + 1 < descriptors_len {
800             flags |= VIRTQ_DESC_F_NEXT;
801         }
802 
803         let index = index as u16;
804         let desc = virtq_desc {
805             addr: buffers_start_addr.offset().into(),
806             len: size.into(),
807             flags: flags.into(),
808             next: (index + 1).into(),
809         };
810 
811         let offset = size + spaces_between_regions;
812         buffers_start_addr = buffers_start_addr
813             .checked_add(offset as u64)
814             .context("Invalid buffers_start_addr)")?;
815 
816         let _ = memory.write_obj_at_addr(
817             desc,
818             descriptor_array_addr
819                 .checked_add(index as u64 * std::mem::size_of::<virtq_desc>() as u64)
820                 .context("Invalid descriptor_array_addr")?,
821         );
822     }
823 
824     let chain = SplitDescriptorChain::new(memory, descriptor_array_addr, 0x100, 0);
825     DescriptorChain::new(chain, memory, 0)
826 }
827 
// Unit tests for `Reader`/`Writer` over descriptor chains built with `create_descriptor_chain`.
#[cfg(test)]
mod tests {
    use std::fs::File;
    use std::io::Read;

    use cros_async::Executor;
    use tempfile::tempfile;
    use tempfile::NamedTempFile;

    use super::*;

    /// Returns a 64 KiB guest memory region starting at guest address 0.
    fn test_memory() -> GuestMemory {
        GuestMemory::new(&[(GuestAddress(0x0), 0x10000)]).unwrap()
    }

    /// Builds a descriptor chain in `memory`.
    ///
    /// The descriptor table lives at guest address 0 and the buffers start at guest address
    /// 0x100. Panics if the chain cannot be created.
    fn new_chain(
        memory: &GuestMemory,
        descriptors: Vec<(DescriptorType, u32)>,
        spaces_between_regions: u32,
    ) -> DescriptorChain {
        create_descriptor_chain(
            memory,
            GuestAddress(0x0),
            GuestAddress(0x100),
            descriptors,
            spaces_between_regions,
        )
        .expect("create_descriptor_chain failed")
    }

    /// Builds the chain shared by several tests, with no gaps between buffers.
    ///
    /// It holds 16 + 16 + 96 = 128 readable bytes followed by 64 + 1 + 3 = 68 writable bytes.
    fn mixed_chain(memory: &GuestMemory) -> DescriptorChain {
        use DescriptorType::*;

        new_chain(
            memory,
            vec![
                (Readable, 16),
                (Readable, 16),
                (Readable, 96),
                (Writable, 64),
                (Writable, 1),
                (Writable, 3),
            ],
            0,
        )
    }

    #[test]
    fn reader_test_simple_chain() {
        use DescriptorType::*;

        let memory = test_memory();
        let mut chain = new_chain(
            &memory,
            vec![
                (Readable, 8),
                (Readable, 16),
                (Readable, 18),
                (Readable, 64),
            ],
            0,
        );
        let reader = &mut chain.reader;
        assert_eq!(reader.available_bytes(), 106);
        assert_eq!(reader.bytes_read(), 0);

        let mut buffer = [0u8; 64];
        reader
            .read_exact(&mut buffer)
            .expect("read_exact should not fail here");

        assert_eq!(reader.available_bytes(), 42);
        assert_eq!(reader.bytes_read(), 64);

        match reader.read(&mut buffer) {
            Err(_) => panic!("read should not fail here"),
            Ok(length) => assert_eq!(length, 42),
        }

        assert_eq!(reader.available_bytes(), 0);
        assert_eq!(reader.bytes_read(), 106);
    }

    #[test]
    fn writer_test_simple_chain() {
        use DescriptorType::*;

        let memory = test_memory();
        let mut chain = new_chain(
            &memory,
            vec![
                (Writable, 8),
                (Writable, 16),
                (Writable, 18),
                (Writable, 64),
            ],
            0,
        );
        let writer = &mut chain.writer;
        assert_eq!(writer.available_bytes(), 106);
        assert_eq!(writer.bytes_written(), 0);

        let buffer = [0; 64];
        writer
            .write_all(&buffer)
            .expect("write_all should not fail here");

        assert_eq!(writer.available_bytes(), 42);
        assert_eq!(writer.bytes_written(), 64);

        match writer.write(&buffer) {
            Err(_) => panic!("write should not fail here"),
            Ok(length) => assert_eq!(length, 42),
        }

        assert_eq!(writer.available_bytes(), 0);
        assert_eq!(writer.bytes_written(), 106);
    }

    #[test]
    fn reader_test_incompatible_chain() {
        use DescriptorType::*;

        let memory = test_memory();
        let mut chain = new_chain(&memory, vec![(Writable, 8)], 0);
        let reader = &mut chain.reader;
        assert_eq!(reader.available_bytes(), 0);
        assert_eq!(reader.bytes_read(), 0);

        assert!(reader.read_obj::<u8>().is_err());

        assert_eq!(reader.available_bytes(), 0);
        assert_eq!(reader.bytes_read(), 0);
    }

    #[test]
    fn writer_test_incompatible_chain() {
        use DescriptorType::*;

        let memory = test_memory();
        let mut chain = new_chain(&memory, vec![(Readable, 8)], 0);
        let writer = &mut chain.writer;
        assert_eq!(writer.available_bytes(), 0);
        assert_eq!(writer.bytes_written(), 0);

        assert!(writer.write_obj(0u8).is_err());

        assert_eq!(writer.available_bytes(), 0);
        assert_eq!(writer.bytes_written(), 0);
    }

    #[test]
    fn reader_failing_io() {
        use DescriptorType::*;

        let memory = test_memory();
        let mut chain = new_chain(&memory, vec![(Readable, 256), (Readable, 256)], 0);

        let reader = &mut chain.reader;

        // Open a device file read-only so that writes to it trigger an I/O error.
        let device_file = if cfg!(windows) { "NUL" } else { "/dev/zero" };
        let mut ro_file = File::open(device_file).expect("failed to open device file");

        reader
            .read_exact_to(&mut ro_file, 512)
            .expect_err("successfully read more bytes than SharedMemory size");

        // The write above should have failed entirely, so we end up not writing any bytes at all.
        assert_eq!(reader.available_bytes(), 512);
        assert_eq!(reader.bytes_read(), 0);
    }

    #[test]
    fn writer_failing_io() {
        use DescriptorType::*;

        let memory = test_memory();
        let mut chain = new_chain(&memory, vec![(Writable, 256), (Writable, 256)], 0);

        let writer = &mut chain.writer;

        let mut file = tempfile().unwrap();

        file.set_len(384).unwrap();

        writer
            .write_all_from(&mut file, 512)
            .expect_err("successfully wrote more bytes than in SharedMemory");

        assert_eq!(writer.available_bytes(), 128);
        assert_eq!(writer.bytes_written(), 384);
    }

    #[test]
    fn reader_writer_shared_chain() {
        let memory = test_memory();
        let mut chain = mixed_chain(&memory);
        let reader = &mut chain.reader;
        let writer = &mut chain.writer;

        assert_eq!(reader.bytes_read(), 0);
        assert_eq!(writer.bytes_written(), 0);

        let mut buffer = Vec::with_capacity(200);

        assert_eq!(
            reader
                .read_to_end(&mut buffer)
                .expect("read should not fail here"),
            128
        );

        // The writable descriptors are only 68 bytes long.
        writer
            .write_all(&buffer[..68])
            .expect("write should not fail here");

        assert_eq!(reader.available_bytes(), 0);
        assert_eq!(reader.bytes_read(), 128);
        assert_eq!(writer.available_bytes(), 0);
        assert_eq!(writer.bytes_written(), 68);
    }

    #[test]
    fn reader_writer_shattered_object() {
        use DescriptorType::*;

        let memory = test_memory();

        let secret: Le32 = 0x12345678.into();

        // Create a descriptor chain with memory regions that are properly separated.
        let mut chain_writer = new_chain(
            &memory,
            vec![(Writable, 1), (Writable, 1), (Writable, 1), (Writable, 1)],
            123,
        );
        let writer = &mut chain_writer.writer;
        writer
            .write_obj(secret)
            .expect("write_obj should not fail here");

        // Now create a new descriptor chain pointing to the same memory and try to read it.
        let mut chain_reader = new_chain(
            &memory,
            vec![(Readable, 1), (Readable, 1), (Readable, 1), (Readable, 1)],
            123,
        );
        let reader = &mut chain_reader.reader;
        match reader.read_obj::<Le32>() {
            Err(_) => panic!("read_obj should not fail here"),
            Ok(read_secret) => assert_eq!(read_secret, secret),
        }
    }

    #[test]
    fn reader_unexpected_eof() {
        use DescriptorType::*;

        let memory = test_memory();
        let mut chain = new_chain(&memory, vec![(Readable, 256), (Readable, 256)], 0);

        let reader = &mut chain.reader;

        let mut buf = vec![0; 1024];

        assert_eq!(
            reader
                .read_exact(&mut buf[..])
                .expect_err("read more bytes than available")
                .kind(),
            io::ErrorKind::UnexpectedEof
        );
    }

    #[test]
    fn split_border() {
        let memory = test_memory();
        let mut chain = mixed_chain(&memory);
        let reader = &mut chain.reader;

        let other = reader.split_at(32);
        assert_eq!(reader.available_bytes(), 32);
        assert_eq!(other.available_bytes(), 96);
    }

    #[test]
    fn split_middle() {
        let memory = test_memory();
        let mut chain = mixed_chain(&memory);
        let reader = &mut chain.reader;

        let other = reader.split_at(24);
        assert_eq!(reader.available_bytes(), 24);
        assert_eq!(other.available_bytes(), 104);
    }

    #[test]
    fn split_end() {
        let memory = test_memory();
        let mut chain = mixed_chain(&memory);
        let reader = &mut chain.reader;

        let other = reader.split_at(128);
        assert_eq!(reader.available_bytes(), 128);
        assert_eq!(other.available_bytes(), 0);
    }

    #[test]
    fn split_beginning() {
        let memory = test_memory();
        let mut chain = mixed_chain(&memory);
        let reader = &mut chain.reader;

        let other = reader.split_at(0);
        assert_eq!(reader.available_bytes(), 0);
        assert_eq!(other.available_bytes(), 128);
    }

    #[test]
    fn split_outofbounds() {
        let memory = test_memory();
        let mut chain = mixed_chain(&memory);
        let reader = &mut chain.reader;

        let other = reader.split_at(256);
        assert_eq!(
            other.available_bytes(),
            0,
            "Reader returned from out-of-bounds split still has available bytes"
        );
    }

    #[test]
    fn read_full() {
        use DescriptorType::*;

        let memory = test_memory();
        let mut chain = new_chain(
            &memory,
            vec![(Readable, 16), (Readable, 16), (Readable, 16)],
            0,
        );
        let reader = &mut chain.reader;

        let mut buf = [0u8; 64];
        assert_eq!(
            reader.read(&mut buf[..]).expect("failed to read to buffer"),
            48
        );
    }

    #[test]
    fn write_full() {
        use DescriptorType::*;

        let memory = test_memory();
        let mut chain = new_chain(
            &memory,
            vec![(Writable, 16), (Writable, 16), (Writable, 16)],
            0,
        );
        let writer = &mut chain.writer;

        let buf = [0xdeu8; 64];
        assert_eq!(
            writer.write(&buf[..]).expect("failed to write from buffer"),
            48
        );
    }

    #[test]
    fn consume_collect() {
        use DescriptorType::*;

        let memory = test_memory();
        let vs: Vec<Le64> = vec![
            0x0101010101010101.into(),
            0x0202020202020202.into(),
            0x0303030303030303.into(),
        ];

        let mut write_chain = new_chain(&memory, vec![(Writable, 24)], 0);
        let writer = &mut write_chain.writer;
        writer
            .consume(vs.clone())
            .expect("failed to consume() a vector");

        let mut read_chain = new_chain(&memory, vec![(Readable, 24)], 0);
        let reader = &mut read_chain.reader;
        let vs_read = reader
            .collect::<io::Result<Vec<Le64>>, _>()
            .expect("failed to collect() values");
        assert_eq!(vs, vs_read);
    }

    #[test]
    fn get_remaining_region_with_count() {
        let memory = test_memory();
        let chain = mixed_chain(&memory);

        let Reader {
            mem: _,
            mut regions,
        } = chain.reader;

        let drain = regions
            .get_remaining_regions_with_count(usize::MAX)
            .fold(0usize, |total, region| total + region.len);
        assert_eq!(drain, 128);

        let exact = regions
            .get_remaining_regions_with_count(32)
            .fold(0usize, |total, region| total + region.len);
        assert!(exact > 0);
        assert!(exact <= 32);

        let split = regions
            .get_remaining_regions_with_count(24)
            .fold(0usize, |total, region| total + region.len);
        assert!(split > 0);
        assert!(split <= 24);

        regions.consume(64);

        let first = regions
            .get_remaining_regions_with_count(8)
            .fold(0usize, |total, region| total + region.len);
        assert!(first > 0);
        assert!(first <= 8);
    }

    #[test]
    fn get_remaining_with_count() {
        let memory = test_memory();
        let chain = mixed_chain(&memory);
        let Reader {
            mem: _,
            mut regions,
        } = chain.reader;

        let drain = regions
            .get_remaining_with_count(&memory, usize::MAX)
            .iter()
            .fold(0usize, |total, iov| total + iov.size());
        assert_eq!(drain, 128);

        let exact = regions
            .get_remaining_with_count(&memory, 32)
            .iter()
            .fold(0usize, |total, iov| total + iov.size());
        assert!(exact > 0);
        assert!(exact <= 32);

        let split = regions
            .get_remaining_with_count(&memory, 24)
            .iter()
            .fold(0usize, |total, iov| total + iov.size());
        assert!(split > 0);
        assert!(split <= 24);

        regions.consume(64);

        let first = regions
            .get_remaining_with_count(&memory, 8)
            .iter()
            .fold(0usize, |total, iov| total + iov.size());
        assert!(first > 0);
        assert!(first <= 8);
    }

    #[test]
    fn reader_peek_obj() {
        use DescriptorType::*;

        let memory = test_memory();

        // Write test data to memory.
        memory
            .write_obj_at_addr(Le16::from(0xBEEF), GuestAddress(0x100))
            .unwrap();
        memory
            .write_obj_at_addr(Le16::from(0xDEAD), GuestAddress(0x200))
            .unwrap();

        let mut chain_reader = new_chain(&memory, vec![(Readable, 2), (Readable, 2)], 0x100 - 2);
        let reader = &mut chain_reader.reader;

        // peek_obj() at the beginning of the chain should return the first object.
        let peek1 = reader.peek_obj::<Le16>().unwrap();
        assert_eq!(peek1, Le16::from(0xBEEF));

        // peek_obj() again should return the same object, since it was not consumed.
        let peek2 = reader.peek_obj::<Le16>().unwrap();
        assert_eq!(peek2, Le16::from(0xBEEF));

        // peek_obj() of an object spanning two descriptors should copy from both.
        let peek3 = reader.peek_obj::<Le32>().unwrap();
        assert_eq!(peek3, Le32::from(0xDEADBEEF));

        // read_obj() should return the first object.
        let read1 = reader.read_obj::<Le16>().unwrap();
        assert_eq!(read1, Le16::from(0xBEEF));

        // peek_obj() of a value that is larger than the rest of the chain should fail.
        reader
            .peek_obj::<Le32>()
            .expect_err("peek_obj past end of chain");

        // read_obj() again should return the second object.
        let read2 = reader.read_obj::<Le16>().unwrap();
        assert_eq!(read2, Le16::from(0xDEAD));

        // peek_obj() should fail at the end of the chain.
        reader
            .peek_obj::<Le16>()
            .expect_err("peek_obj past end of chain");
    }

    #[test]
    fn region_reader_failing_io() {
        let ex = Executor::new().unwrap();
        ex.run_until(region_reader_failing_io_async(&ex)).unwrap();
    }
    async fn region_reader_failing_io_async(ex: &Executor) {
        use DescriptorType::*;

        let memory = test_memory();
        let mut chain = new_chain(&memory, vec![(Readable, 256), (Readable, 256)], 0);

        let reader = &mut chain.reader;

        // Open a file in read-only mode so that writes to it trigger an I/O error.
        let named_temp_file = NamedTempFile::new().expect("failed to create temp file");
        let ro_file =
            File::open(named_temp_file.path()).expect("failed to open temp file read only");
        let async_ro_file =
            disk::SingleFileDisk::new(ro_file, ex).expect("Failed to create SFD");

        reader
            .read_exact_to_at_fut(&async_ro_file, 512, 0)
            .await
            .expect_err("successfully read more bytes than SingleFileDisk size");

        // The write above should have failed entirely, so we end up not writing any bytes at all.
        assert_eq!(reader.available_bytes(), 512);
        assert_eq!(reader.bytes_read(), 0);
    }

    #[test]
    fn region_writer_failing_io() {
        let ex = Executor::new().unwrap();
        ex.run_until(region_writer_failing_io_async(&ex)).unwrap();
    }
    async fn region_writer_failing_io_async(ex: &Executor) {
        use DescriptorType::*;

        let memory = test_memory();
        let mut chain = new_chain(&memory, vec![(Writable, 256), (Writable, 256)], 0);

        let writer = &mut chain.writer;

        let file = tempfile().expect("failed to create temp file");

        file.set_len(384).unwrap();
        let async_file = disk::SingleFileDisk::new(file, ex).expect("Failed to create SFD");

        writer
            .write_all_from_at_fut(&async_file, 512, 0)
            .await
            .expect_err("successfully wrote more bytes than in SingleFileDisk");

        assert_eq!(writer.available_bytes(), 128);
        assert_eq!(writer.bytes_written(), 384);
    }
}
1625