1 // Copyright 2019 The Chromium OS Authors. All rights reserved.
2 // Use of this source code is governed by a BSD-style license that can be
3 // found in the LICENSE file.
4
5 use std::borrow::Cow;
6 use std::cmp;
7 use std::convert::TryInto;
8 use std::fmt::{self, Display};
9 use std::io::{self, Read, Write};
10 use std::iter::FromIterator;
11 use std::marker::PhantomData;
12 use std::mem::{size_of, MaybeUninit};
13 use std::ptr::copy_nonoverlapping;
14 use std::result;
15 use std::sync::Arc;
16
17 use base::{FileReadWriteAtVolatile, FileReadWriteVolatile};
18 use cros_async::MemRegion;
19 use data_model::{DataInit, Le16, Le32, Le64, VolatileMemoryError, VolatileSlice};
20 use disk::AsyncDisk;
21 use smallvec::SmallVec;
22 use vm_memory::{GuestAddress, GuestMemory};
23
24 use super::DescriptorChain;
25
26 #[derive(Debug)]
27 pub enum Error {
28 DescriptorChainOverflow,
29 GuestMemoryError(vm_memory::GuestMemoryError),
30 InvalidChain,
31 IoError(io::Error),
32 SplitOutOfBounds(usize),
33 VolatileMemoryError(VolatileMemoryError),
34 }
35
36 impl Display for Error {
fmt(&self, f: &mut fmt::Formatter) -> fmt::Result37 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
38 use self::Error::*;
39
40 match self {
41 DescriptorChainOverflow => write!(
42 f,
43 "the combined length of all the buffers in a `DescriptorChain` would overflow"
44 ),
45 GuestMemoryError(e) => write!(f, "descriptor guest memory error: {}", e),
46 InvalidChain => write!(f, "invalid descriptor chain"),
47 IoError(e) => write!(f, "descriptor I/O error: {}", e),
48 SplitOutOfBounds(off) => write!(f, "`DescriptorChain` split is out of bounds: {}", off),
49 VolatileMemoryError(e) => write!(f, "volatile memory error: {}", e),
50 }
51 }
52 }
53
54 pub type Result<T> = result::Result<T, Error>;
55
56 impl std::error::Error for Error {}
57
58 #[derive(Clone)]
59 struct DescriptorChainRegions {
60 regions: SmallVec<[MemRegion; 16]>,
61 current: usize,
62 bytes_consumed: usize,
63 }
64
65 impl DescriptorChainRegions {
available_bytes(&self) -> usize66 fn available_bytes(&self) -> usize {
67 // This is guaranteed not to overflow because the total length of the chain
68 // is checked during all creations of `DescriptorChainRegions` (see
69 // `Reader::new()` and `Writer::new()`).
70 self.get_remaining_regions()
71 .iter()
72 .fold(0usize, |count, region| count + region.len)
73 }
74
bytes_consumed(&self) -> usize75 fn bytes_consumed(&self) -> usize {
76 self.bytes_consumed
77 }
78
79 /// Returns all the remaining buffers in the `DescriptorChain`. Calling this function does not
80 /// consume any bytes from the `DescriptorChain`. Instead callers should use the `consume`
81 /// method to advance the `DescriptorChain`. Multiple calls to `get` with no intervening calls
82 /// to `consume` will return the same data.
get_remaining_regions(&self) -> &[MemRegion]83 fn get_remaining_regions(&self) -> &[MemRegion] {
84 &self.regions[self.current..]
85 }
86
87 /// Returns all the remaining buffers in the `DescriptorChain` as `VolatileSlice`s of the given
88 /// `GuestMemory`. Calling this function does not consume any bytes from the `DescriptorChain`.
89 /// Instead callers should use the `consume` method to advance the `DescriptorChain`. Multiple
90 /// calls to `get` with no intervening calls to `consume` will return the same data.
get_remaining<'mem>(&self, mem: &'mem GuestMemory) -> SmallVec<[VolatileSlice<'mem>; 16]>91 fn get_remaining<'mem>(&self, mem: &'mem GuestMemory) -> SmallVec<[VolatileSlice<'mem>; 16]> {
92 self.get_remaining_regions()
93 .iter()
94 .filter_map(|region| {
95 mem.get_slice_at_addr(GuestAddress(region.offset), region.len)
96 .ok()
97 })
98 .collect()
99 }
100
101 /// Like `get_remaining` but guarantees that the combined length of all the returned iovecs is
102 /// not greater than `count`. The combined length of the returned iovecs may be less than
103 /// `count` but will always be greater than 0 as long as there is still space left in the
104 /// `DescriptorChain`.
get_remaining_regions_with_count(&self, count: usize) -> Cow<[MemRegion]>105 fn get_remaining_regions_with_count(&self, count: usize) -> Cow<[MemRegion]> {
106 let regions = self.get_remaining_regions();
107 let mut region_count = 0;
108 let mut rem = count;
109 for region in regions {
110 if rem < region.len {
111 break;
112 }
113
114 region_count += 1;
115 rem -= region.len;
116 }
117
118 // Special case where the number of bytes to be copied is smaller than the `size()` of the
119 // first regions.
120 if region_count == 0 && !regions.is_empty() && count > 0 {
121 debug_assert!(count < regions[0].len);
122 // Safe because we know that count is smaller than the length of the first slice.
123 Cow::Owned(vec![MemRegion {
124 offset: regions[0].offset,
125 len: count,
126 }])
127 } else {
128 Cow::Borrowed(®ions[..region_count])
129 }
130 }
131
132 /// Like 'get_remaining_with_count' except convert the offsets to volatile slices in the
133 /// 'GuestMemory' given by 'mem'.
get_remaining_with_count<'mem>( &self, mem: &'mem GuestMemory, count: usize, ) -> SmallVec<[VolatileSlice<'mem>; 16]>134 fn get_remaining_with_count<'mem>(
135 &self,
136 mem: &'mem GuestMemory,
137 count: usize,
138 ) -> SmallVec<[VolatileSlice<'mem>; 16]> {
139 self.get_remaining_regions_with_count(count)
140 .iter()
141 .filter_map(|region| {
142 mem.get_slice_at_addr(GuestAddress(region.offset), region.len)
143 .ok()
144 })
145 .collect()
146 }
147
148 /// Consumes `count` bytes from the `DescriptorChain`. If `count` is larger than
149 /// `self.available_bytes()` then all remaining bytes in the `DescriptorChain` will be consumed.
consume(&mut self, mut count: usize)150 fn consume(&mut self, mut count: usize) {
151 // The implementation is adapted from `IoSlice::advance` in libstd. We can't use
152 // `get_remaining` here because then the compiler complains that `self.current` is already
153 // borrowed and doesn't allow us to modify it. We also need to borrow the iovecs mutably.
154 let current = self.current;
155 for region in &mut self.regions[current..] {
156 if count == 0 {
157 break;
158 }
159
160 let consumed = if count < region.len {
161 // Safe because we know that the iovec pointed to valid memory and we are adding a
162 // value that is smaller than the length of the memory.
163 *region = MemRegion {
164 offset: region.offset + count as u64,
165 len: region.len - count,
166 };
167 count
168 } else {
169 self.current += 1;
170 region.len
171 };
172
173 // This shouldn't overflow because `consumed <= buf.size()` and we already verified
174 // that adding all `buf.size()` values will not overflow when the Reader/Writer was
175 // constructed.
176 self.bytes_consumed += consumed;
177 count -= consumed;
178 }
179 }
180
split_at(&mut self, offset: usize) -> DescriptorChainRegions181 fn split_at(&mut self, offset: usize) -> DescriptorChainRegions {
182 let mut other = self.clone();
183 other.consume(offset);
184 other.bytes_consumed = 0;
185
186 let mut rem = offset;
187 let mut end = self.current;
188 for region in &mut self.regions[self.current..] {
189 if rem <= region.len {
190 region.len = rem;
191 break;
192 }
193
194 end += 1;
195 rem -= region.len;
196 }
197
198 self.regions.truncate(end + 1);
199
200 other
201 }
202 }
203
204 /// Provides high-level interface over the sequence of memory regions
205 /// defined by readable descriptors in the descriptor chain.
206 ///
207 /// Note that virtio spec requires driver to place any device-writable
208 /// descriptors after any device-readable descriptors (2.6.4.2 in Virtio Spec v1.1).
209 /// Reader will skip iterating over descriptor chain when first writable
210 /// descriptor is encountered.
211 #[derive(Clone)]
212 pub struct Reader {
213 mem: GuestMemory,
214 regions: DescriptorChainRegions,
215 }
216
217 /// An iterator over `DataInit` objects on readable descriptors in the descriptor chain.
218 pub struct ReaderIterator<'a, T: DataInit> {
219 reader: &'a mut Reader,
220 phantom: PhantomData<T>,
221 }
222
223 impl<'a, T: DataInit> Iterator for ReaderIterator<'a, T> {
224 type Item = io::Result<T>;
225
next(&mut self) -> Option<io::Result<T>>226 fn next(&mut self) -> Option<io::Result<T>> {
227 if self.reader.available_bytes() == 0 {
228 None
229 } else {
230 Some(self.reader.read_obj())
231 }
232 }
233 }
234
235 impl Reader {
236 /// Construct a new Reader wrapper over `desc_chain`.
new(mem: GuestMemory, desc_chain: DescriptorChain) -> Result<Reader>237 pub fn new(mem: GuestMemory, desc_chain: DescriptorChain) -> Result<Reader> {
238 // TODO(jstaron): Update this code to take the indirect descriptors into account.
239 let mut total_len: usize = 0;
240 let regions = desc_chain
241 .into_iter()
242 .readable()
243 .map(|desc| {
244 // Verify that summing the descriptor sizes does not overflow.
245 // This can happen if a driver tricks a device into reading more data than
246 // fits in a `usize`.
247 total_len = total_len
248 .checked_add(desc.len as usize)
249 .ok_or(Error::DescriptorChainOverflow)?;
250
251 // Check that all the regions are totally contained in GuestMemory.
252 mem.get_slice_at_addr(
253 desc.addr,
254 desc.len.try_into().expect("u32 doesn't fit in usize"),
255 )
256 .map_err(Error::GuestMemoryError)?;
257
258 Ok(MemRegion {
259 offset: desc.addr.0,
260 len: desc.len.try_into().expect("u32 doesn't fit in usize"),
261 })
262 })
263 .collect::<Result<SmallVec<[MemRegion; 16]>>>()?;
264 Ok(Reader {
265 mem,
266 regions: DescriptorChainRegions {
267 regions,
268 current: 0,
269 bytes_consumed: 0,
270 },
271 })
272 }
273
274 /// Reads an object from the descriptor chain buffer.
read_obj<T: DataInit>(&mut self) -> io::Result<T>275 pub fn read_obj<T: DataInit>(&mut self) -> io::Result<T> {
276 let mut obj = MaybeUninit::<T>::uninit();
277
278 // Safe because `MaybeUninit` guarantees that the pointer is valid for
279 // `size_of::<T>()` bytes.
280 let buf = unsafe {
281 ::std::slice::from_raw_parts_mut(obj.as_mut_ptr() as *mut u8, size_of::<T>())
282 };
283
284 self.read_exact(buf)?;
285
286 // Safe because any type that implements `DataInit` can be considered initialized
287 // even if it is filled with random data.
288 Ok(unsafe { obj.assume_init() })
289 }
290
291 /// Reads objects by consuming all the remaining data in the descriptor chain buffer and returns
292 /// them as a collection. Returns an error if the size of the remaining data is indivisible by
293 /// the size of an object of type `T`.
collect<C: FromIterator<io::Result<T>>, T: DataInit>(&mut self) -> C294 pub fn collect<C: FromIterator<io::Result<T>>, T: DataInit>(&mut self) -> C {
295 self.iter().collect()
296 }
297
298 /// Creates an iterator for sequentially reading `DataInit` objects from the `Reader`.
299 /// Unlike `collect`, this doesn't consume all the remaining data in the `Reader` and
300 /// doesn't require the objects to be stored in a separate collection.
iter<T: DataInit>(&mut self) -> ReaderIterator<T>301 pub fn iter<T: DataInit>(&mut self) -> ReaderIterator<T> {
302 ReaderIterator {
303 reader: self,
304 phantom: PhantomData,
305 }
306 }
307
308 /// Reads data from the descriptor chain buffer into a file descriptor.
309 /// Returns the number of bytes read from the descriptor chain buffer.
310 /// The number of bytes read can be less than `count` if there isn't
311 /// enough data in the descriptor chain buffer.
read_to<F: FileReadWriteVolatile>( &mut self, mut dst: F, count: usize, ) -> io::Result<usize>312 pub fn read_to<F: FileReadWriteVolatile>(
313 &mut self,
314 mut dst: F,
315 count: usize,
316 ) -> io::Result<usize> {
317 let iovs = self.regions.get_remaining_with_count(&self.mem, count);
318 let written = dst.write_vectored_volatile(&iovs[..])?;
319 self.regions.consume(written);
320 Ok(written)
321 }
322
323 /// Reads data from the descriptor chain buffer into a File at offset `off`.
324 /// Returns the number of bytes read from the descriptor chain buffer.
325 /// The number of bytes read can be less than `count` if there isn't
326 /// enough data in the descriptor chain buffer.
read_to_at<F: FileReadWriteAtVolatile>( &mut self, mut dst: F, count: usize, off: u64, ) -> io::Result<usize>327 pub fn read_to_at<F: FileReadWriteAtVolatile>(
328 &mut self,
329 mut dst: F,
330 count: usize,
331 off: u64,
332 ) -> io::Result<usize> {
333 let iovs = self.regions.get_remaining_with_count(&self.mem, count);
334 let written = dst.write_vectored_at_volatile(&iovs[..], off)?;
335 self.regions.consume(written);
336 Ok(written)
337 }
338
339 /// Reads data from the descriptor chain similar to 'read_to' except reading 'count' or
340 /// returning an error if 'count' bytes can't be read.
read_exact_to<F: FileReadWriteVolatile>( &mut self, mut dst: F, mut count: usize, ) -> io::Result<()>341 pub fn read_exact_to<F: FileReadWriteVolatile>(
342 &mut self,
343 mut dst: F,
344 mut count: usize,
345 ) -> io::Result<()> {
346 while count > 0 {
347 match self.read_to(&mut dst, count) {
348 Ok(0) => {
349 return Err(io::Error::new(
350 io::ErrorKind::UnexpectedEof,
351 "failed to fill whole buffer",
352 ))
353 }
354 Ok(n) => count -= n,
355 Err(ref e) if e.kind() == io::ErrorKind::Interrupted => {}
356 Err(e) => return Err(e),
357 }
358 }
359
360 Ok(())
361 }
362
363 /// Reads data from the descriptor chain similar to 'read_to_at' except reading 'count' or
364 /// returning an error if 'count' bytes can't be read.
read_exact_to_at<F: FileReadWriteAtVolatile>( &mut self, mut dst: F, mut count: usize, mut off: u64, ) -> io::Result<()>365 pub fn read_exact_to_at<F: FileReadWriteAtVolatile>(
366 &mut self,
367 mut dst: F,
368 mut count: usize,
369 mut off: u64,
370 ) -> io::Result<()> {
371 while count > 0 {
372 match self.read_to_at(&mut dst, count, off) {
373 Ok(0) => {
374 return Err(io::Error::new(
375 io::ErrorKind::UnexpectedEof,
376 "failed to fill whole buffer",
377 ))
378 }
379 Ok(n) => {
380 count -= n;
381 off += n as u64;
382 }
383 Err(ref e) if e.kind() == io::ErrorKind::Interrupted => {}
384 Err(e) => return Err(e),
385 }
386 }
387
388 Ok(())
389 }
390
391 /// Reads data from the descriptor chain buffer into an `AsyncDisk` at offset `off`.
392 /// Returns the number of bytes read from the descriptor chain buffer.
393 /// The number of bytes read can be less than `count` if there isn't
394 /// enough data in the descriptor chain buffer.
read_to_at_fut<F: AsyncDisk + ?Sized>( &mut self, dst: &F, count: usize, off: u64, ) -> disk::Result<usize>395 pub async fn read_to_at_fut<F: AsyncDisk + ?Sized>(
396 &mut self,
397 dst: &F,
398 count: usize,
399 off: u64,
400 ) -> disk::Result<usize> {
401 let mem_regions = self.regions.get_remaining_regions_with_count(count);
402 let written = dst
403 .write_from_mem(off, Arc::new(self.mem.clone()), &mem_regions)
404 .await?;
405 self.regions.consume(written);
406 Ok(written)
407 }
408
409 /// Reads exactly `count` bytes from the chain to the disk asynchronously or returns an error if
410 /// not enough data can be read.
read_exact_to_at_fut<F: AsyncDisk + ?Sized>( &mut self, dst: &F, mut count: usize, mut off: u64, ) -> disk::Result<()>411 pub async fn read_exact_to_at_fut<F: AsyncDisk + ?Sized>(
412 &mut self,
413 dst: &F,
414 mut count: usize,
415 mut off: u64,
416 ) -> disk::Result<()> {
417 while count > 0 {
418 let nread = self.read_to_at_fut(dst, count, off).await?;
419 if nread == 0 {
420 return Err(disk::Error::ReadingData(io::Error::new(
421 io::ErrorKind::UnexpectedEof,
422 "failed to write whole buffer",
423 )));
424 }
425 count -= nread;
426 off += nread as u64;
427 }
428
429 Ok(())
430 }
431
432 /// Returns number of bytes available for reading. May return an error if the combined
433 /// lengths of all the buffers in the DescriptorChain would cause an integer overflow.
available_bytes(&self) -> usize434 pub fn available_bytes(&self) -> usize {
435 self.regions.available_bytes()
436 }
437
438 /// Returns number of bytes already read from the descriptor chain buffer.
bytes_read(&self) -> usize439 pub fn bytes_read(&self) -> usize {
440 self.regions.bytes_consumed()
441 }
442
443 /// Returns a `&[VolatileSlice]` that represents all the remaining data in this `Reader`.
444 /// Calling this method does not actually consume any data from the `Reader` and callers should
445 /// call `consume` to advance the `Reader`.
get_remaining(&self) -> SmallVec<[VolatileSlice; 16]>446 pub fn get_remaining(&self) -> SmallVec<[VolatileSlice; 16]> {
447 self.regions.get_remaining(&self.mem)
448 }
449
450 /// Consumes `amt` bytes from the underlying descriptor chain. If `amt` is larger than the
451 /// remaining data left in this `Reader`, then all remaining data will be consumed.
consume(&mut self, amt: usize)452 pub fn consume(&mut self, amt: usize) {
453 self.regions.consume(amt)
454 }
455
456 /// Splits this `Reader` into two at the given offset in the `DescriptorChain` buffer. After the
457 /// split, `self` will be able to read up to `offset` bytes while the returned `Reader` can read
458 /// up to `available_bytes() - offset` bytes. If `offset > self.available_bytes()`, then the
459 /// returned `Reader` will not be able to read any bytes.
split_at(&mut self, offset: usize) -> Reader460 pub fn split_at(&mut self, offset: usize) -> Reader {
461 Reader {
462 mem: self.mem.clone(),
463 regions: self.regions.split_at(offset),
464 }
465 }
466 }
467
468 impl io::Read for Reader {
read(&mut self, buf: &mut [u8]) -> io::Result<usize>469 fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
470 let mut rem = buf;
471 let mut total = 0;
472 for b in self.regions.get_remaining(&self.mem) {
473 if rem.is_empty() {
474 break;
475 }
476
477 let count = cmp::min(rem.len(), b.size());
478
479 // Safe because we have already verified that `b` points to valid memory.
480 unsafe {
481 copy_nonoverlapping(b.as_ptr(), rem.as_mut_ptr(), count);
482 }
483 rem = &mut rem[count..];
484 total += count;
485 }
486
487 self.regions.consume(total);
488 Ok(total)
489 }
490 }
491
492 /// Provides high-level interface over the sequence of memory regions
493 /// defined by writable descriptors in the descriptor chain.
494 ///
495 /// Note that virtio spec requires driver to place any device-writable
496 /// descriptors after any device-readable descriptors (2.6.4.2 in Virtio Spec v1.1).
497 /// Writer will start iterating the descriptors from the first writable one and will
498 /// assume that all following descriptors are writable.
499 #[derive(Clone)]
500 pub struct Writer {
501 mem: GuestMemory,
502 regions: DescriptorChainRegions,
503 }
504
505 impl Writer {
506 /// Construct a new Writer wrapper over `desc_chain`.
new(mem: GuestMemory, desc_chain: DescriptorChain) -> Result<Writer>507 pub fn new(mem: GuestMemory, desc_chain: DescriptorChain) -> Result<Writer> {
508 let mut total_len: usize = 0;
509 let regions = desc_chain
510 .into_iter()
511 .writable()
512 .map(|desc| {
513 // Verify that summing the descriptor sizes does not overflow.
514 // This can happen if a driver tricks a device into writing more data than
515 // fits in a `usize`.
516 total_len = total_len
517 .checked_add(desc.len as usize)
518 .ok_or(Error::DescriptorChainOverflow)?;
519
520 mem.get_slice_at_addr(
521 desc.addr,
522 desc.len.try_into().expect("u32 doesn't fit in usize"),
523 )
524 .map_err(Error::GuestMemoryError)?;
525
526 Ok(MemRegion {
527 offset: desc.addr.0,
528 len: desc.len.try_into().expect("u32 doesn't fit in usize"),
529 })
530 })
531 .collect::<Result<SmallVec<[MemRegion; 16]>>>()?;
532 Ok(Writer {
533 mem,
534 regions: DescriptorChainRegions {
535 regions,
536 current: 0,
537 bytes_consumed: 0,
538 },
539 })
540 }
541
542 /// Writes an object to the descriptor chain buffer.
write_obj<T: DataInit>(&mut self, val: T) -> io::Result<()>543 pub fn write_obj<T: DataInit>(&mut self, val: T) -> io::Result<()> {
544 self.write_all(val.as_slice())
545 }
546
547 /// Writes all objects produced by `iter` into the descriptor chain buffer. Unlike `consume`,
548 /// this doesn't require the values to be stored in an intermediate collection first. It also
549 /// allows callers to choose which elements in a collection to write, for example by using the
550 /// `filter` or `take` methods of the `Iterator` trait.
write_iter<T: DataInit, I: Iterator<Item = T>>( &mut self, mut iter: I, ) -> io::Result<()>551 pub fn write_iter<T: DataInit, I: Iterator<Item = T>>(
552 &mut self,
553 mut iter: I,
554 ) -> io::Result<()> {
555 iter.try_for_each(|v| self.write_obj(v))
556 }
557
558 /// Writes a collection of objects into the descriptor chain buffer.
consume<T: DataInit, C: IntoIterator<Item = T>>(&mut self, vals: C) -> io::Result<()>559 pub fn consume<T: DataInit, C: IntoIterator<Item = T>>(&mut self, vals: C) -> io::Result<()> {
560 self.write_iter(vals.into_iter())
561 }
562
563 /// Returns number of bytes available for writing. May return an error if the combined
564 /// lengths of all the buffers in the DescriptorChain would cause an overflow.
available_bytes(&self) -> usize565 pub fn available_bytes(&self) -> usize {
566 self.regions.available_bytes()
567 }
568
569 /// Writes data to the descriptor chain buffer from a file descriptor.
570 /// Returns the number of bytes written to the descriptor chain buffer.
571 /// The number of bytes written can be less than `count` if
572 /// there isn't enough data in the descriptor chain buffer.
write_from<F: FileReadWriteVolatile>( &mut self, mut src: F, count: usize, ) -> io::Result<usize>573 pub fn write_from<F: FileReadWriteVolatile>(
574 &mut self,
575 mut src: F,
576 count: usize,
577 ) -> io::Result<usize> {
578 let iovs = self.regions.get_remaining_with_count(&self.mem, count);
579 let read = src.read_vectored_volatile(&iovs[..])?;
580 self.regions.consume(read);
581 Ok(read)
582 }
583
584 /// Writes data to the descriptor chain buffer from a File at offset `off`.
585 /// Returns the number of bytes written to the descriptor chain buffer.
586 /// The number of bytes written can be less than `count` if
587 /// there isn't enough data in the descriptor chain buffer.
write_from_at<F: FileReadWriteAtVolatile>( &mut self, mut src: F, count: usize, off: u64, ) -> io::Result<usize>588 pub fn write_from_at<F: FileReadWriteAtVolatile>(
589 &mut self,
590 mut src: F,
591 count: usize,
592 off: u64,
593 ) -> io::Result<usize> {
594 let iovs = self.regions.get_remaining_with_count(&self.mem, count);
595 let read = src.read_vectored_at_volatile(&iovs[..], off)?;
596 self.regions.consume(read);
597 Ok(read)
598 }
599
write_all_from<F: FileReadWriteVolatile>( &mut self, mut src: F, mut count: usize, ) -> io::Result<()>600 pub fn write_all_from<F: FileReadWriteVolatile>(
601 &mut self,
602 mut src: F,
603 mut count: usize,
604 ) -> io::Result<()> {
605 while count > 0 {
606 match self.write_from(&mut src, count) {
607 Ok(0) => {
608 return Err(io::Error::new(
609 io::ErrorKind::WriteZero,
610 "failed to write whole buffer",
611 ))
612 }
613 Ok(n) => count -= n,
614 Err(ref e) if e.kind() == io::ErrorKind::Interrupted => {}
615 Err(e) => return Err(e),
616 }
617 }
618
619 Ok(())
620 }
621
write_all_from_at<F: FileReadWriteAtVolatile>( &mut self, mut src: F, mut count: usize, mut off: u64, ) -> io::Result<()>622 pub fn write_all_from_at<F: FileReadWriteAtVolatile>(
623 &mut self,
624 mut src: F,
625 mut count: usize,
626 mut off: u64,
627 ) -> io::Result<()> {
628 while count > 0 {
629 match self.write_from_at(&mut src, count, off) {
630 Ok(0) => {
631 return Err(io::Error::new(
632 io::ErrorKind::WriteZero,
633 "failed to write whole buffer",
634 ))
635 }
636 Ok(n) => {
637 count -= n;
638 off += n as u64;
639 }
640 Err(ref e) if e.kind() == io::ErrorKind::Interrupted => {}
641 Err(e) => return Err(e),
642 }
643 }
644 Ok(())
645 }
646 /// Writes data to the descriptor chain buffer from an `AsyncDisk` at offset `off`.
647 /// Returns the number of bytes written to the descriptor chain buffer.
648 /// The number of bytes written can be less than `count` if
649 /// there isn't enough data in the descriptor chain buffer.
write_from_at_fut<F: AsyncDisk + ?Sized>( &mut self, src: &F, count: usize, off: u64, ) -> disk::Result<usize>650 pub async fn write_from_at_fut<F: AsyncDisk + ?Sized>(
651 &mut self,
652 src: &F,
653 count: usize,
654 off: u64,
655 ) -> disk::Result<usize> {
656 let regions = self.regions.get_remaining_regions_with_count(count);
657 let read = src
658 .read_to_mem(off, Arc::new(self.mem.clone()), ®ions)
659 .await?;
660 self.regions.consume(read);
661 Ok(read)
662 }
663
write_all_from_at_fut<F: AsyncDisk + ?Sized>( &mut self, src: &F, mut count: usize, mut off: u64, ) -> disk::Result<()>664 pub async fn write_all_from_at_fut<F: AsyncDisk + ?Sized>(
665 &mut self,
666 src: &F,
667 mut count: usize,
668 mut off: u64,
669 ) -> disk::Result<()> {
670 while count > 0 {
671 let nwritten = self.write_from_at_fut(src, count, off).await?;
672 if nwritten == 0 {
673 return Err(disk::Error::WritingData(io::Error::new(
674 io::ErrorKind::UnexpectedEof,
675 "failed to write whole buffer",
676 )));
677 }
678 count -= nwritten;
679 off += nwritten as u64;
680 }
681 Ok(())
682 }
683
684 /// Returns number of bytes already written to the descriptor chain buffer.
bytes_written(&self) -> usize685 pub fn bytes_written(&self) -> usize {
686 self.regions.bytes_consumed()
687 }
688
689 /// Splits this `Writer` into two at the given offset in the `DescriptorChain` buffer. After the
690 /// split, `self` will be able to write up to `offset` bytes while the returned `Writer` can
691 /// write up to `available_bytes() - offset` bytes. If `offset > self.available_bytes()`, then
692 /// the returned `Writer` will not be able to write any bytes.
split_at(&mut self, offset: usize) -> Writer693 pub fn split_at(&mut self, offset: usize) -> Writer {
694 Writer {
695 mem: self.mem.clone(),
696 regions: self.regions.split_at(offset),
697 }
698 }
699 }
700
701 impl io::Write for Writer {
write(&mut self, buf: &[u8]) -> io::Result<usize>702 fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
703 let mut rem = buf;
704 let mut total = 0;
705 for b in self.regions.get_remaining(&self.mem) {
706 if rem.is_empty() {
707 break;
708 }
709
710 let count = cmp::min(rem.len(), b.size());
711 // Safe because we have already verified that `vs` points to valid memory.
712 unsafe {
713 copy_nonoverlapping(rem.as_ptr(), b.as_mut_ptr(), count);
714 }
715 rem = &rem[count..];
716 total += count;
717 }
718
719 self.regions.consume(total);
720 Ok(total)
721 }
722
flush(&mut self) -> io::Result<()>723 fn flush(&mut self) -> io::Result<()> {
724 // Nothing to flush since the writes go straight into the buffer.
725 Ok(())
726 }
727 }
728
729 const VIRTQ_DESC_F_NEXT: u16 = 0x1;
730 const VIRTQ_DESC_F_WRITE: u16 = 0x2;
731
732 #[derive(Copy, Clone, PartialEq, Eq)]
733 pub enum DescriptorType {
734 Readable,
735 Writable,
736 }
737
738 #[derive(Copy, Clone, Debug)]
739 #[repr(C)]
740 struct virtq_desc {
741 addr: Le64,
742 len: Le32,
743 flags: Le16,
744 next: Le16,
745 }
746
747 // Safe because it only has data and has no implicit padding.
748 unsafe impl DataInit for virtq_desc {}
749
750 /// Test utility function to create a descriptor chain in guest memory.
create_descriptor_chain( memory: &GuestMemory, descriptor_array_addr: GuestAddress, mut buffers_start_addr: GuestAddress, descriptors: Vec<(DescriptorType, u32)>, spaces_between_regions: u32, ) -> Result<DescriptorChain>751 pub fn create_descriptor_chain(
752 memory: &GuestMemory,
753 descriptor_array_addr: GuestAddress,
754 mut buffers_start_addr: GuestAddress,
755 descriptors: Vec<(DescriptorType, u32)>,
756 spaces_between_regions: u32,
757 ) -> Result<DescriptorChain> {
758 let descriptors_len = descriptors.len();
759 for (index, (type_, size)) in descriptors.into_iter().enumerate() {
760 let mut flags = 0;
761 if let DescriptorType::Writable = type_ {
762 flags |= VIRTQ_DESC_F_WRITE;
763 }
764 if index + 1 < descriptors_len {
765 flags |= VIRTQ_DESC_F_NEXT;
766 }
767
768 let index = index as u16;
769 let desc = virtq_desc {
770 addr: buffers_start_addr.offset().into(),
771 len: size.into(),
772 flags: flags.into(),
773 next: (index + 1).into(),
774 };
775
776 let offset = size + spaces_between_regions;
777 buffers_start_addr = buffers_start_addr
778 .checked_add(offset as u64)
779 .ok_or(Error::InvalidChain)?;
780
781 let _ = memory.write_obj_at_addr(
782 desc,
783 descriptor_array_addr
784 .checked_add(index as u64 * std::mem::size_of::<virtq_desc>() as u64)
785 .ok_or(Error::InvalidChain)?,
786 );
787 }
788
789 DescriptorChain::checked_new(memory, descriptor_array_addr, 0x100, 0, 0)
790 .ok_or(Error::InvalidChain)
791 }
792
793 #[cfg(test)]
794 mod tests {
795 use super::*;
796 use std::fs::File;
797 use tempfile::tempfile;
798
799 use cros_async::Executor;
800
801 #[test]
reader_test_simple_chain()802 fn reader_test_simple_chain() {
803 use DescriptorType::*;
804
805 let memory_start_addr = GuestAddress(0x0);
806 let memory = GuestMemory::new(&vec![(memory_start_addr, 0x10000)]).unwrap();
807
808 let chain = create_descriptor_chain(
809 &memory,
810 GuestAddress(0x0),
811 GuestAddress(0x100),
812 vec![
813 (Readable, 8),
814 (Readable, 16),
815 (Readable, 18),
816 (Readable, 64),
817 ],
818 0,
819 )
820 .expect("create_descriptor_chain failed");
821 let mut reader = Reader::new(memory.clone(), chain).expect("failed to create Reader");
822 assert_eq!(reader.available_bytes(), 106);
823 assert_eq!(reader.bytes_read(), 0);
824
825 let mut buffer = [0 as u8; 64];
826 if let Err(_) = reader.read_exact(&mut buffer) {
827 panic!("read_exact should not fail here");
828 }
829
830 assert_eq!(reader.available_bytes(), 42);
831 assert_eq!(reader.bytes_read(), 64);
832
833 match reader.read(&mut buffer) {
834 Err(_) => panic!("read should not fail here"),
835 Ok(length) => assert_eq!(length, 42),
836 }
837
838 assert_eq!(reader.available_bytes(), 0);
839 assert_eq!(reader.bytes_read(), 106);
840 }
841
842 #[test]
writer_test_simple_chain()843 fn writer_test_simple_chain() {
844 use DescriptorType::*;
845
846 let memory_start_addr = GuestAddress(0x0);
847 let memory = GuestMemory::new(&vec![(memory_start_addr, 0x10000)]).unwrap();
848
849 let chain = create_descriptor_chain(
850 &memory,
851 GuestAddress(0x0),
852 GuestAddress(0x100),
853 vec![
854 (Writable, 8),
855 (Writable, 16),
856 (Writable, 18),
857 (Writable, 64),
858 ],
859 0,
860 )
861 .expect("create_descriptor_chain failed");
862 let mut writer = Writer::new(memory.clone(), chain).expect("failed to create Writer");
863 assert_eq!(writer.available_bytes(), 106);
864 assert_eq!(writer.bytes_written(), 0);
865
866 let mut buffer = [0 as u8; 64];
867 if let Err(_) = writer.write_all(&mut buffer) {
868 panic!("write_all should not fail here");
869 }
870
871 assert_eq!(writer.available_bytes(), 42);
872 assert_eq!(writer.bytes_written(), 64);
873
874 match writer.write(&mut buffer) {
875 Err(_) => panic!("write should not fail here"),
876 Ok(length) => assert_eq!(length, 42),
877 }
878
879 assert_eq!(writer.available_bytes(), 0);
880 assert_eq!(writer.bytes_written(), 106);
881 }
882
883 #[test]
reader_test_incompatible_chain()884 fn reader_test_incompatible_chain() {
885 use DescriptorType::*;
886
887 let memory_start_addr = GuestAddress(0x0);
888 let memory = GuestMemory::new(&vec![(memory_start_addr, 0x10000)]).unwrap();
889
890 let chain = create_descriptor_chain(
891 &memory,
892 GuestAddress(0x0),
893 GuestAddress(0x100),
894 vec![(Writable, 8)],
895 0,
896 )
897 .expect("create_descriptor_chain failed");
898 let mut reader = Reader::new(memory.clone(), chain).expect("failed to create Reader");
899 assert_eq!(reader.available_bytes(), 0);
900 assert_eq!(reader.bytes_read(), 0);
901
902 assert!(reader.read_obj::<u8>().is_err());
903
904 assert_eq!(reader.available_bytes(), 0);
905 assert_eq!(reader.bytes_read(), 0);
906 }
907
/// A `Writer` built over a chain that holds only a readable descriptor has
/// no room to write: `write_obj` must fail and leave both counters at zero.
#[test]
fn writer_test_incompatible_chain() {
    use DescriptorType::*;

    let memory_start_addr = GuestAddress(0x0);
    let memory = GuestMemory::new(&[(memory_start_addr, 0x10000)]).unwrap();

    let chain = create_descriptor_chain(
        &memory,
        GuestAddress(0x0),
        GuestAddress(0x100),
        vec![(Readable, 8)],
        0,
    )
    .expect("create_descriptor_chain failed");
    let mut writer = Writer::new(memory.clone(), chain).expect("failed to create Writer");
    assert_eq!(writer.available_bytes(), 0);
    assert_eq!(writer.bytes_written(), 0);

    assert!(writer.write_obj(0u8).is_err());

    // The failed write must not have consumed anything.
    assert_eq!(writer.available_bytes(), 0);
    assert_eq!(writer.bytes_written(), 0);
}
932
/// Copying the readable part of a chain into a sink that rejects writes must
/// fail without consuming any bytes from the `Reader`.
#[test]
fn reader_failing_io() {
    use DescriptorType::*;

    let memory_start_addr = GuestAddress(0x0);
    let memory = GuestMemory::new(&[(memory_start_addr, 0x10000)]).unwrap();

    let chain = create_descriptor_chain(
        &memory,
        GuestAddress(0x0),
        GuestAddress(0x100),
        vec![(Readable, 256), (Readable, 256)],
        0,
    )
    .expect("create_descriptor_chain failed");

    let mut reader = Reader::new(memory.clone(), chain).expect("failed to create Reader");

    // Open the file read-only so that writing the chain's bytes into it
    // triggers an I/O error.
    let mut ro_file = File::open("/dev/zero").expect("failed to open /dev/zero");

    reader
        .read_exact_to(&mut ro_file, 512)
        .expect_err("successfully read more bytes than SharedMemory size");

    // The write above should have failed entirely, so no bytes were consumed.
    assert_eq!(reader.available_bytes(), 512);
    assert_eq!(reader.bytes_read(), 0);
}
962
/// Filling 512 writable bytes from a file that holds only 384 bytes must
/// fail, but the 384 bytes that were copied are still accounted as written.
#[test]
fn writer_failing_io() {
    use DescriptorType::*;

    let memory_start_addr = GuestAddress(0x0);
    let memory = GuestMemory::new(&[(memory_start_addr, 0x10000)]).unwrap();

    let chain = create_descriptor_chain(
        &memory,
        GuestAddress(0x0),
        GuestAddress(0x100),
        vec![(Writable, 256), (Writable, 256)],
        0,
    )
    .expect("create_descriptor_chain failed");

    let mut writer = Writer::new(memory.clone(), chain).expect("failed to create Writer");

    let mut file = tempfile().expect("failed to create temp file");

    // Make the source shorter than the 512 bytes requested below.
    file.set_len(384).unwrap();

    writer
        .write_all_from(&mut file, 512)
        .expect_err("successfully wrote more bytes than in SharedMemory");

    // Partial progress is kept: 384 bytes written, 128 bytes still free.
    assert_eq!(writer.available_bytes(), 128);
    assert_eq!(writer.bytes_written(), 384);
}
992
/// A `Reader` and a `Writer` built over clones of the same mixed chain each
/// see only their own descriptors: 128 readable bytes and 68 writable bytes.
#[test]
fn reader_writer_shared_chain() {
    use DescriptorType::*;

    let memory_start_addr = GuestAddress(0x0);
    let memory = GuestMemory::new(&[(memory_start_addr, 0x10000)]).unwrap();

    let chain = create_descriptor_chain(
        &memory,
        GuestAddress(0x0),
        GuestAddress(0x100),
        vec![
            (Readable, 16),
            (Readable, 16),
            (Readable, 96),
            (Writable, 64),
            (Writable, 1),
            (Writable, 3),
        ],
        0,
    )
    .expect("create_descriptor_chain failed");
    let mut reader =
        Reader::new(memory.clone(), chain.clone()).expect("failed to create Reader");
    let mut writer = Writer::new(memory.clone(), chain).expect("failed to create Writer");

    assert_eq!(reader.bytes_read(), 0);
    assert_eq!(writer.bytes_written(), 0);

    let mut buffer = Vec::with_capacity(200);

    // 16 + 16 + 96 readable bytes.
    assert_eq!(
        reader
            .read_to_end(&mut buffer)
            .expect("read should not fail here"),
        128
    );

    // The writable descriptors are only 64 + 1 + 3 = 68 bytes long.
    writer
        .write_all(&buffer[..68])
        .expect("write should not fail here");

    assert_eq!(reader.available_bytes(), 0);
    assert_eq!(reader.bytes_read(), 128);
    assert_eq!(writer.available_bytes(), 0);
    assert_eq!(writer.bytes_written(), 68);
}
1041
/// A 4-byte object written across four 1-byte descriptors can be read back
/// intact through a second chain that covers the same guest memory.
#[test]
fn reader_writer_shattered_object() {
    use DescriptorType::*;

    let memory_start_addr = GuestAddress(0x0);
    let memory = GuestMemory::new(&[(memory_start_addr, 0x10000)]).unwrap();

    let secret: Le32 = 0x12345678.into();

    // Create a descriptor chain with memory regions that are properly separated.
    // NOTE(review): the final argument (123) presumably sets the gap between
    // descriptor buffers; confirm against `create_descriptor_chain`.
    let chain_writer = create_descriptor_chain(
        &memory,
        GuestAddress(0x0),
        GuestAddress(0x100),
        vec![(Writable, 1), (Writable, 1), (Writable, 1), (Writable, 1)],
        123,
    )
    .expect("create_descriptor_chain failed");
    let mut writer =
        Writer::new(memory.clone(), chain_writer).expect("failed to create Writer");
    writer
        .write_obj(secret)
        .expect("write_obj should not fail here");

    // Now create a new descriptor chain pointing to the same memory and read it back.
    let chain_reader = create_descriptor_chain(
        &memory,
        GuestAddress(0x0),
        GuestAddress(0x100),
        vec![(Readable, 1), (Readable, 1), (Readable, 1), (Readable, 1)],
        123,
    )
    .expect("create_descriptor_chain failed");
    let mut reader =
        Reader::new(memory.clone(), chain_reader).expect("failed to create Reader");
    let read_secret = reader
        .read_obj::<Le32>()
        .expect("read_obj should not fail here");
    assert_eq!(read_secret, secret);
}
1082
/// `read_exact` into a buffer larger than the chain's 512 readable bytes
/// must fail with `ErrorKind::UnexpectedEof`.
#[test]
fn reader_unexpected_eof() {
    use DescriptorType::*;

    let memory_start_addr = GuestAddress(0x0);
    let memory = GuestMemory::new(&[(memory_start_addr, 0x10000)]).unwrap();

    let chain = create_descriptor_chain(
        &memory,
        GuestAddress(0x0),
        GuestAddress(0x100),
        vec![(Readable, 256), (Readable, 256)],
        0,
    )
    .expect("create_descriptor_chain failed");

    let mut reader = Reader::new(memory.clone(), chain).expect("failed to create Reader");

    // Twice as large as the readable part of the chain.
    let mut buf = vec![0u8; 1024];

    assert_eq!(
        reader
            .read_exact(&mut buf[..])
            .expect_err("read more bytes than available")
            .kind(),
        io::ErrorKind::UnexpectedEof
    );
}
1112
/// Splitting at offset 32, exactly on the boundary between the second and
/// third readable descriptors, leaves 32 bytes in the original `Reader` and
/// moves the remaining 96 readable bytes to the returned one.
#[test]
fn split_border() {
    use DescriptorType::*;

    let memory_start_addr = GuestAddress(0x0);
    let memory = GuestMemory::new(&[(memory_start_addr, 0x10000)]).unwrap();

    let chain = create_descriptor_chain(
        &memory,
        GuestAddress(0x0),
        GuestAddress(0x100),
        vec![
            (Readable, 16),
            (Readable, 16),
            (Readable, 96),
            (Writable, 64),
            (Writable, 1),
            (Writable, 3),
        ],
        0,
    )
    .expect("create_descriptor_chain failed");
    let mut reader = Reader::new(memory.clone(), chain).expect("failed to create Reader");

    let other = reader.split_at(32);
    assert_eq!(reader.available_bytes(), 32);
    assert_eq!(other.available_bytes(), 96);
}
1141
/// Splitting at offset 24, inside the second readable descriptor, leaves
/// 24 bytes in the original `Reader` and 104 bytes in the returned one.
#[test]
fn split_middle() {
    use DescriptorType::*;

    let memory_start_addr = GuestAddress(0x0);
    let memory = GuestMemory::new(&[(memory_start_addr, 0x10000)]).unwrap();

    let chain = create_descriptor_chain(
        &memory,
        GuestAddress(0x0),
        GuestAddress(0x100),
        vec![
            (Readable, 16),
            (Readable, 16),
            (Readable, 96),
            (Writable, 64),
            (Writable, 1),
            (Writable, 3),
        ],
        0,
    )
    .expect("create_descriptor_chain failed");
    let mut reader = Reader::new(memory.clone(), chain).expect("failed to create Reader");

    let other = reader.split_at(24);
    assert_eq!(reader.available_bytes(), 24);
    assert_eq!(other.available_bytes(), 104);
}
1170
/// Splitting at offset 128, the total readable length, keeps every byte in
/// the original `Reader` and returns an empty one.
#[test]
fn split_end() {
    use DescriptorType::*;

    let memory_start_addr = GuestAddress(0x0);
    let memory = GuestMemory::new(&[(memory_start_addr, 0x10000)]).unwrap();

    let chain = create_descriptor_chain(
        &memory,
        GuestAddress(0x0),
        GuestAddress(0x100),
        vec![
            (Readable, 16),
            (Readable, 16),
            (Readable, 96),
            (Writable, 64),
            (Writable, 1),
            (Writable, 3),
        ],
        0,
    )
    .expect("create_descriptor_chain failed");
    let mut reader = Reader::new(memory.clone(), chain).expect("failed to create Reader");

    let other = reader.split_at(128);
    assert_eq!(reader.available_bytes(), 128);
    assert_eq!(other.available_bytes(), 0);
}
1199
/// Splitting at offset 0 empties the original `Reader` and moves all 128
/// readable bytes to the returned one.
#[test]
fn split_beginning() {
    use DescriptorType::*;

    let memory_start_addr = GuestAddress(0x0);
    let memory = GuestMemory::new(&[(memory_start_addr, 0x10000)]).unwrap();

    let chain = create_descriptor_chain(
        &memory,
        GuestAddress(0x0),
        GuestAddress(0x100),
        vec![
            (Readable, 16),
            (Readable, 16),
            (Readable, 96),
            (Writable, 64),
            (Writable, 1),
            (Writable, 3),
        ],
        0,
    )
    .expect("create_descriptor_chain failed");
    let mut reader = Reader::new(memory.clone(), chain).expect("failed to create Reader");

    let other = reader.split_at(0);
    assert_eq!(reader.available_bytes(), 0);
    assert_eq!(other.available_bytes(), 128);
}
1228
/// Splitting past the end of the readable bytes (256 > 128) must not panic,
/// and the returned `Reader` must have nothing available.
#[test]
fn split_outofbounds() {
    use DescriptorType::*;

    let memory_start_addr = GuestAddress(0x0);
    let memory = GuestMemory::new(&[(memory_start_addr, 0x10000)]).unwrap();

    let chain = create_descriptor_chain(
        &memory,
        GuestAddress(0x0),
        GuestAddress(0x100),
        vec![
            (Readable, 16),
            (Readable, 16),
            (Readable, 96),
            (Writable, 64),
            (Writable, 1),
            (Writable, 3),
        ],
        0,
    )
    .expect("create_descriptor_chain failed");
    let mut reader = Reader::new(memory.clone(), chain).expect("failed to create Reader");

    let other = reader.split_at(256);
    assert_eq!(
        other.available_bytes(),
        0,
        "Reader returned from out-of-bounds split still has available bytes"
    );
}
1260
/// A single `read` into a 64-byte buffer returns all 48 readable bytes, a
/// short read spanning three descriptors.
#[test]
fn read_full() {
    use DescriptorType::*;

    let memory_start_addr = GuestAddress(0x0);
    let memory = GuestMemory::new(&[(memory_start_addr, 0x10000)]).unwrap();

    let chain = create_descriptor_chain(
        &memory,
        GuestAddress(0x0),
        GuestAddress(0x100),
        vec![(Readable, 16), (Readable, 16), (Readable, 16)],
        0,
    )
    .expect("create_descriptor_chain failed");
    let mut reader = Reader::new(memory.clone(), chain).expect("failed to create Reader");

    // Larger than the chain, so the read is bounded by the descriptors.
    let mut buf = [0u8; 64];
    assert_eq!(
        reader.read(&mut buf[..]).expect("failed to read to buffer"),
        48
    );
}
1284
/// A single `write` of a 64-byte buffer stores 48 bytes, filling all three
/// writable descriptors.
#[test]
fn write_full() {
    use DescriptorType::*;

    let memory_start_addr = GuestAddress(0x0);
    let memory = GuestMemory::new(&[(memory_start_addr, 0x10000)]).unwrap();

    let chain = create_descriptor_chain(
        &memory,
        GuestAddress(0x0),
        GuestAddress(0x100),
        vec![(Writable, 16), (Writable, 16), (Writable, 16)],
        0,
    )
    .expect("create_descriptor_chain failed");
    let mut writer = Writer::new(memory.clone(), chain).expect("failed to create Writer");

    // Larger than the chain, so the write is bounded by the descriptors.
    let buf = [0xdeu8; 64];
    assert_eq!(
        writer.write(&buf[..]).expect("failed to write from buffer"),
        48
    );
}
1308
/// Round-trips three `Le64` values through guest memory. `Writer::consume`
/// writes them, and `Reader::collect` reads them back as a `Vec`.
#[test]
fn consume_collect() {
    use DescriptorType::*;

    let memory_start_addr = GuestAddress(0x0);
    let memory = GuestMemory::new(&[(memory_start_addr, 0x10000)]).unwrap();
    let vs: Vec<Le64> = vec![
        0x0101010101010101.into(),
        0x0202020202020202.into(),
        0x0303030303030303.into(),
    ];

    // 24 bytes == 3 * size_of::<Le64>().
    let write_chain = create_descriptor_chain(
        &memory,
        GuestAddress(0x0),
        GuestAddress(0x100),
        vec![(Writable, 24)],
        0,
    )
    .expect("create_descriptor_chain failed");
    let mut writer = Writer::new(memory.clone(), write_chain).expect("failed to create Writer");
    writer
        .consume(vs.clone())
        .expect("failed to consume() a vector");

    let read_chain = create_descriptor_chain(
        &memory,
        GuestAddress(0x0),
        GuestAddress(0x100),
        vec![(Readable, 24)],
        0,
    )
    .expect("create_descriptor_chain failed");
    let mut reader = Reader::new(memory.clone(), read_chain).expect("failed to create Reader");
    let vs_read = reader
        .collect::<io::Result<Vec<Le64>>, _>()
        .expect("failed to collect() values");
    assert_eq!(vs, vs_read);
}
1348
/// `get_remaining_regions_with_count` behaves as follows:
/// - requesting `usize::MAX` returns every readable byte (128);
/// - smaller requests return a non-empty set of regions totalling no more
///   than the requested count, including after 64 bytes were consumed.
#[test]
fn get_remaining_region_with_count() {
    use DescriptorType::*;

    let memory_start_addr = GuestAddress(0x0);
    let memory = GuestMemory::new(&[(memory_start_addr, 0x10000)]).unwrap();

    let chain = create_descriptor_chain(
        &memory,
        GuestAddress(0x0),
        GuestAddress(0x100),
        vec![
            (Readable, 16),
            (Readable, 16),
            (Readable, 96),
            (Writable, 64),
            (Writable, 1),
            (Writable, 3),
        ],
        0,
    )
    .expect("create_descriptor_chain failed");

    let Reader { mut regions, .. } =
        Reader::new(memory.clone(), chain).expect("failed to create Reader");

    let drain = regions
        .get_remaining_regions_with_count(usize::MAX)
        .iter()
        .map(|region| region.len)
        .sum::<usize>();
    assert_eq!(drain, 128);

    let exact = regions
        .get_remaining_regions_with_count(32)
        .iter()
        .map(|region| region.len)
        .sum::<usize>();
    assert!(exact > 0);
    assert!(exact <= 32);

    let split = regions
        .get_remaining_regions_with_count(24)
        .iter()
        .map(|region| region.len)
        .sum::<usize>();
    assert!(split > 0);
    assert!(split <= 24);

    regions.consume(64);

    let first = regions
        .get_remaining_regions_with_count(8)
        .iter()
        .map(|region| region.len)
        .sum::<usize>();
    assert!(first > 0);
    assert!(first <= 8);
}
1406
/// Same checks as the region-based variant, for `get_remaining_with_count`,
/// which resolves the regions against guest memory. The size of the
/// returned slices is bounded by the requested count, and `usize::MAX`
/// returns all 128 readable bytes.
#[test]
fn get_remaining_with_count() {
    use DescriptorType::*;

    let memory_start_addr = GuestAddress(0x0);
    let memory = GuestMemory::new(&[(memory_start_addr, 0x10000)]).unwrap();

    let chain = create_descriptor_chain(
        &memory,
        GuestAddress(0x0),
        GuestAddress(0x100),
        vec![
            (Readable, 16),
            (Readable, 16),
            (Readable, 96),
            (Writable, 64),
            (Writable, 1),
            (Writable, 3),
        ],
        0,
    )
    .expect("create_descriptor_chain failed");
    let Reader { mut regions, .. } =
        Reader::new(memory.clone(), chain).expect("failed to create Reader");

    let drain = regions
        .get_remaining_with_count(&memory, usize::MAX)
        .iter()
        .map(|iov| iov.size())
        .sum::<usize>();
    assert_eq!(drain, 128);

    let exact = regions
        .get_remaining_with_count(&memory, 32)
        .iter()
        .map(|iov| iov.size())
        .sum::<usize>();
    assert!(exact > 0);
    assert!(exact <= 32);

    let split = regions
        .get_remaining_with_count(&memory, 24)
        .iter()
        .map(|iov| iov.size())
        .sum::<usize>();
    assert!(split > 0);
    assert!(split <= 24);

    regions.consume(64);

    let first = regions
        .get_remaining_with_count(&memory, 8)
        .iter()
        .map(|iov| iov.size())
        .sum::<usize>();
    assert!(first > 0);
    assert!(first <= 8);
}
1463
/// Synchronous entry point: runs `region_reader_failing_io_async` to
/// completion on a freshly created executor.
#[test]
fn region_reader_failing_io() {
    let executor = Executor::new().unwrap();
    let test_future = region_reader_failing_io_async(&executor);
    executor.run_until(test_future).unwrap();
}
region_reader_failing_io_async(ex: &Executor)1469 async fn region_reader_failing_io_async(ex: &Executor) {
1470 use DescriptorType::*;
1471
1472 let memory_start_addr = GuestAddress(0x0);
1473 let memory = GuestMemory::new(&vec![(memory_start_addr, 0x10000)]).unwrap();
1474
1475 let chain = create_descriptor_chain(
1476 &memory,
1477 GuestAddress(0x0),
1478 GuestAddress(0x100),
1479 vec![(Readable, 256), (Readable, 256)],
1480 0,
1481 )
1482 .expect("create_descriptor_chain failed");
1483
1484 let mut reader = Reader::new(memory.clone(), chain).expect("failed to create Reader");
1485
1486 // Open a file in read-only mode so writes to it to trigger an I/O error.
1487 let ro_file = File::open("/dev/zero").expect("failed to open /dev/zero");
1488 let async_ro_file = disk::SingleFileDisk::new(ro_file, ex).expect("Failed to crate SFD");
1489
1490 reader
1491 .read_exact_to_at_fut(&async_ro_file, 512, 0)
1492 .await
1493 .expect_err("successfully read more bytes than SharedMemory size");
1494
1495 // The write above should have failed entirely, so we end up not writing any bytes at all.
1496 assert_eq!(reader.available_bytes(), 512);
1497 assert_eq!(reader.bytes_read(), 0);
1498 }
1499
/// Synchronous entry point: runs `region_writer_failing_io_async` to
/// completion on a freshly created executor.
#[test]
fn region_writer_failing_io() {
    let executor = Executor::new().unwrap();
    let test_future = region_writer_failing_io_async(&executor);
    executor.run_until(test_future).unwrap();
}
region_writer_failing_io_async(ex: &Executor)1505 async fn region_writer_failing_io_async(ex: &Executor) {
1506 use DescriptorType::*;
1507
1508 let memory_start_addr = GuestAddress(0x0);
1509 let memory = GuestMemory::new(&vec![(memory_start_addr, 0x10000)]).unwrap();
1510
1511 let chain = create_descriptor_chain(
1512 &memory,
1513 GuestAddress(0x0),
1514 GuestAddress(0x100),
1515 vec![(Writable, 256), (Writable, 256)],
1516 0,
1517 )
1518 .expect("create_descriptor_chain failed");
1519
1520 let mut writer = Writer::new(memory.clone(), chain).expect("failed to create Writer");
1521
1522 let file = tempfile().expect("failed to create temp file");
1523
1524 file.set_len(384).unwrap();
1525 let async_file = disk::SingleFileDisk::new(file, ex).expect("Failed to crate SFD");
1526
1527 writer
1528 .write_all_from_at_fut(&async_file, 512, 0)
1529 .await
1530 .expect_err("successfully wrote more bytes than in SharedMemory");
1531
1532 assert_eq!(writer.available_bytes(), 128);
1533 assert_eq!(writer.bytes_written(), 384);
1534 }
1535 }
1536