1 // Copyright 2019 The ChromiumOS Authors
2 // Use of this source code is governed by a BSD-style license that can be
3 // found in the LICENSE file.
4
5 use std::cmp;
6 use std::io;
7 use std::io::Write;
8 use std::iter::FromIterator;
9 use std::marker::PhantomData;
10 use std::mem::size_of;
11 use std::mem::MaybeUninit;
12 use std::ptr::copy_nonoverlapping;
13 use std::sync::Arc;
14
15 use anyhow::Context;
16 use base::FileReadWriteAtVolatile;
17 use base::FileReadWriteVolatile;
18 use base::VolatileSlice;
19 use cros_async::MemRegion;
20 use cros_async::MemRegionIter;
21 use data_model::Le16;
22 use data_model::Le32;
23 use data_model::Le64;
24 use disk::AsyncDisk;
25 use smallvec::SmallVec;
26 use vm_memory::GuestAddress;
27 use vm_memory::GuestMemory;
28 use zerocopy::FromBytes;
29 use zerocopy::Immutable;
30 use zerocopy::IntoBytes;
31 use zerocopy::KnownLayout;
32
33 use super::DescriptorChain;
34 use crate::virtio::SplitDescriptorChain;
35
/// Tracks a read/write position within the guest memory regions of a descriptor chain.
///
/// Bytes are consumed front-to-back: `current_region_index` selects the region currently being
/// consumed and `current_region_offset` records how far into that region consumption has
/// progressed.
struct DescriptorChainRegions {
    // The guest memory regions of the chain, in chain order.
    regions: SmallVec<[MemRegion; 2]>,

    // Index of the current region in `regions`.
    current_region_index: usize,

    // Number of bytes consumed in the current region.
    current_region_offset: usize,

    // Total bytes consumed in the entire descriptor chain.
    bytes_consumed: usize,
}
48
49 impl DescriptorChainRegions {
new(regions: SmallVec<[MemRegion; 2]>) -> Self50 fn new(regions: SmallVec<[MemRegion; 2]>) -> Self {
51 DescriptorChainRegions {
52 regions,
53 current_region_index: 0,
54 current_region_offset: 0,
55 bytes_consumed: 0,
56 }
57 }
58
available_bytes(&self) -> usize59 fn available_bytes(&self) -> usize {
60 // This is guaranteed not to overflow because the total length of the chain is checked
61 // during all creations of `DescriptorChain` (see `DescriptorChain::new()`).
62 self.get_remaining_regions()
63 .fold(0usize, |count, region| count + region.len)
64 }
65
bytes_consumed(&self) -> usize66 fn bytes_consumed(&self) -> usize {
67 self.bytes_consumed
68 }
69
70 /// Returns all the remaining buffers in the `DescriptorChain`. Calling this function does not
71 /// consume any bytes from the `DescriptorChain`. Instead callers should use the `consume`
72 /// method to advance the `DescriptorChain`. Multiple calls to `get` with no intervening calls
73 /// to `consume` will return the same data.
get_remaining_regions(&self) -> MemRegionIter74 fn get_remaining_regions(&self) -> MemRegionIter {
75 MemRegionIter::new(&self.regions[self.current_region_index..])
76 .skip_bytes(self.current_region_offset)
77 }
78
79 /// Like `get_remaining_regions` but guarantees that the combined length of all the returned
80 /// iovecs is not greater than `count`. The combined length of the returned iovecs may be less
81 /// than `count` but will always be greater than 0 as long as there is still space left in the
82 /// `DescriptorChain`.
get_remaining_regions_with_count(&self, count: usize) -> MemRegionIter83 fn get_remaining_regions_with_count(&self, count: usize) -> MemRegionIter {
84 MemRegionIter::new(&self.regions[self.current_region_index..])
85 .skip_bytes(self.current_region_offset)
86 .take_bytes(count)
87 }
88
89 /// Returns all the remaining buffers in the `DescriptorChain` as `VolatileSlice`s of the given
90 /// `GuestMemory`. Calling this function does not consume any bytes from the `DescriptorChain`.
91 /// Instead callers should use the `consume` method to advance the `DescriptorChain`. Multiple
92 /// calls to `get` with no intervening calls to `consume` will return the same data.
get_remaining<'mem>(&self, mem: &'mem GuestMemory) -> SmallVec<[VolatileSlice<'mem>; 16]>93 fn get_remaining<'mem>(&self, mem: &'mem GuestMemory) -> SmallVec<[VolatileSlice<'mem>; 16]> {
94 self.get_remaining_regions()
95 .filter_map(|region| {
96 mem.get_slice_at_addr(GuestAddress(region.offset), region.len)
97 .ok()
98 })
99 .collect()
100 }
101
102 /// Like 'get_remaining_regions_with_count' except convert the offsets to volatile slices in
103 /// the 'GuestMemory' given by 'mem'.
get_remaining_with_count<'mem>( &self, mem: &'mem GuestMemory, count: usize, ) -> SmallVec<[VolatileSlice<'mem>; 16]>104 fn get_remaining_with_count<'mem>(
105 &self,
106 mem: &'mem GuestMemory,
107 count: usize,
108 ) -> SmallVec<[VolatileSlice<'mem>; 16]> {
109 self.get_remaining_regions_with_count(count)
110 .filter_map(|region| {
111 mem.get_slice_at_addr(GuestAddress(region.offset), region.len)
112 .ok()
113 })
114 .collect()
115 }
116
117 /// Consumes `count` bytes from the `DescriptorChain`. If `count` is larger than
118 /// `self.available_bytes()` then all remaining bytes in the `DescriptorChain` will be consumed.
consume(&mut self, mut count: usize)119 fn consume(&mut self, mut count: usize) {
120 while let Some(region) = self.regions.get(self.current_region_index) {
121 let region_remaining = region.len - self.current_region_offset;
122 if count < region_remaining {
123 // The remaining count to consume is less than the remaining un-consumed length of
124 // the current region. Adjust the region offset without advancing to the next region
125 // and stop.
126 self.current_region_offset += count;
127 self.bytes_consumed += count;
128 return;
129 }
130
131 // The current region has been exhausted. Advance to the next region.
132 self.current_region_index += 1;
133 self.current_region_offset = 0;
134
135 self.bytes_consumed += region_remaining;
136 count -= region_remaining;
137 }
138 }
139
split_at(&mut self, offset: usize) -> DescriptorChainRegions140 fn split_at(&mut self, offset: usize) -> DescriptorChainRegions {
141 let mut other = DescriptorChainRegions {
142 regions: self.regions.clone(),
143 current_region_index: self.current_region_index,
144 current_region_offset: self.current_region_offset,
145 bytes_consumed: self.bytes_consumed,
146 };
147 other.consume(offset);
148 other.bytes_consumed = 0;
149
150 let mut rem = offset;
151 let mut end = self.current_region_index;
152 for region in &mut self.regions[self.current_region_index..] {
153 if rem <= region.len {
154 region.len = rem;
155 break;
156 }
157
158 end += 1;
159 rem -= region.len;
160 }
161
162 self.regions.truncate(end + 1);
163
164 other
165 }
166 }
167
/// Provides a high-level interface over the sequence of memory regions
/// defined by readable descriptors in the descriptor chain.
///
/// Note that virtio spec requires driver to place any device-writable
/// descriptors after any device-readable descriptors (2.6.4.2 in Virtio Spec v1.1).
/// Reader will skip iterating over descriptor chain when first writable
/// descriptor is encountered.
pub struct Reader {
    // Guest memory against which the regions' offsets are resolved.
    mem: GuestMemory,
    // The readable regions together with the current read position within them.
    regions: DescriptorChainRegions,
}
179
/// An iterator over `FromBytes` objects on readable descriptors in the descriptor chain.
///
/// Created by `Reader::iter`; each item is read (and consumed) from the borrowed `Reader`.
pub struct ReaderIterator<'a, T: FromBytes> {
    // The reader that items are pulled from.
    reader: &'a mut Reader,
    // Records the item type `T` without storing a value of it.
    phantom: PhantomData<T>,
}
185
186 impl<T: FromBytes> Iterator for ReaderIterator<'_, T> {
187 type Item = io::Result<T>;
188
next(&mut self) -> Option<io::Result<T>>189 fn next(&mut self) -> Option<io::Result<T>> {
190 if self.reader.available_bytes() == 0 {
191 None
192 } else {
193 Some(self.reader.read_obj())
194 }
195 }
196 }
197
198 impl Reader {
199 /// Construct a new Reader wrapper over `readable_regions`.
new_from_regions( mem: &GuestMemory, readable_regions: SmallVec<[MemRegion; 2]>, ) -> Reader200 pub fn new_from_regions(
201 mem: &GuestMemory,
202 readable_regions: SmallVec<[MemRegion; 2]>,
203 ) -> Reader {
204 Reader {
205 mem: mem.clone(),
206 regions: DescriptorChainRegions::new(readable_regions),
207 }
208 }
209
210 /// Reads an object from the descriptor chain buffer without consuming it.
peek_obj<T: FromBytes>(&self) -> io::Result<T>211 pub fn peek_obj<T: FromBytes>(&self) -> io::Result<T> {
212 let mut obj = MaybeUninit::uninit();
213
214 // SAFETY: We pass a valid pointer and size of `obj`.
215 let copied = unsafe {
216 copy_regions_to_mut_ptr(
217 &self.mem,
218 self.get_remaining_regions(),
219 obj.as_mut_ptr() as *mut u8,
220 size_of::<T>(),
221 )?
222 };
223 if copied != size_of::<T>() {
224 return Err(io::Error::from(io::ErrorKind::UnexpectedEof));
225 }
226
227 // SAFETY: `FromBytes` guarantees any set of initialized bytes is a valid value for `T`, and
228 // we initialized all bytes in `obj` in the copy above.
229 Ok(unsafe { obj.assume_init() })
230 }
231
232 /// Reads and consumes an object from the descriptor chain buffer.
read_obj<T: FromBytes>(&mut self) -> io::Result<T>233 pub fn read_obj<T: FromBytes>(&mut self) -> io::Result<T> {
234 let obj = self.peek_obj::<T>()?;
235 self.consume(size_of::<T>());
236 Ok(obj)
237 }
238
239 /// Reads objects by consuming all the remaining data in the descriptor chain buffer and returns
240 /// them as a collection. Returns an error if the size of the remaining data is indivisible by
241 /// the size of an object of type `T`.
collect<C: FromIterator<io::Result<T>>, T: FromBytes>(&mut self) -> C242 pub fn collect<C: FromIterator<io::Result<T>>, T: FromBytes>(&mut self) -> C {
243 self.iter().collect()
244 }
245
246 /// Creates an iterator for sequentially reading `FromBytes` objects from the `Reader`.
247 /// Unlike `collect`, this doesn't consume all the remaining data in the `Reader` and
248 /// doesn't require the objects to be stored in a separate collection.
iter<T: FromBytes>(&mut self) -> ReaderIterator<T>249 pub fn iter<T: FromBytes>(&mut self) -> ReaderIterator<T> {
250 ReaderIterator {
251 reader: self,
252 phantom: PhantomData,
253 }
254 }
255
256 /// Reads data into a volatile slice up to the minimum of the slice's length or the number of
257 /// bytes remaining. Returns the number of bytes read.
read_to_volatile_slice(&mut self, slice: VolatileSlice) -> usize258 pub fn read_to_volatile_slice(&mut self, slice: VolatileSlice) -> usize {
259 let mut read = 0usize;
260 let mut dst = slice;
261 for src in self.get_remaining() {
262 src.copy_to_volatile_slice(dst);
263 let copied = std::cmp::min(src.size(), dst.size());
264 read += copied;
265 dst = match dst.offset(copied) {
266 Ok(v) => v,
267 Err(_) => break, // The slice is fully consumed
268 };
269 }
270 self.regions.consume(read);
271 read
272 }
273
274 /// Reads data from the descriptor chain buffer and passes the `VolatileSlice`s to the callback
275 /// `cb`.
read_to_cb<C: FnOnce(&[VolatileSlice]) -> usize>( &mut self, cb: C, count: usize, ) -> usize276 pub fn read_to_cb<C: FnOnce(&[VolatileSlice]) -> usize>(
277 &mut self,
278 cb: C,
279 count: usize,
280 ) -> usize {
281 let iovs = self.regions.get_remaining_with_count(&self.mem, count);
282 let written = cb(&iovs[..]);
283 self.regions.consume(written);
284 written
285 }
286
287 /// Reads data from the descriptor chain buffer into a writable object.
288 /// Returns the number of bytes read from the descriptor chain buffer.
289 /// The number of bytes read can be less than `count` if there isn't
290 /// enough data in the descriptor chain buffer.
read_to<F: FileReadWriteVolatile>( &mut self, mut dst: F, count: usize, ) -> io::Result<usize>291 pub fn read_to<F: FileReadWriteVolatile>(
292 &mut self,
293 mut dst: F,
294 count: usize,
295 ) -> io::Result<usize> {
296 let iovs = self.regions.get_remaining_with_count(&self.mem, count);
297 let written = dst.write_vectored_volatile(&iovs[..])?;
298 self.regions.consume(written);
299 Ok(written)
300 }
301
302 /// Reads data from the descriptor chain buffer into a File at offset `off`.
303 /// Returns the number of bytes read from the descriptor chain buffer.
304 /// The number of bytes read can be less than `count` if there isn't
305 /// enough data in the descriptor chain buffer.
read_to_at<F: FileReadWriteAtVolatile>( &mut self, dst: &F, count: usize, off: u64, ) -> io::Result<usize>306 pub fn read_to_at<F: FileReadWriteAtVolatile>(
307 &mut self,
308 dst: &F,
309 count: usize,
310 off: u64,
311 ) -> io::Result<usize> {
312 let iovs = self.regions.get_remaining_with_count(&self.mem, count);
313 let written = dst.write_vectored_at_volatile(&iovs[..], off)?;
314 self.regions.consume(written);
315 Ok(written)
316 }
317
318 /// Reads data from the descriptor chain similar to 'read_to' except reading 'count' or
319 /// returning an error if 'count' bytes can't be read.
read_exact_to<F: FileReadWriteVolatile>( &mut self, mut dst: F, mut count: usize, ) -> io::Result<()>320 pub fn read_exact_to<F: FileReadWriteVolatile>(
321 &mut self,
322 mut dst: F,
323 mut count: usize,
324 ) -> io::Result<()> {
325 while count > 0 {
326 match self.read_to(&mut dst, count) {
327 Ok(0) => {
328 return Err(io::Error::new(
329 io::ErrorKind::UnexpectedEof,
330 "failed to fill whole buffer",
331 ))
332 }
333 Ok(n) => count -= n,
334 Err(ref e) if e.kind() == io::ErrorKind::Interrupted => {}
335 Err(e) => return Err(e),
336 }
337 }
338
339 Ok(())
340 }
341
342 /// Reads data from the descriptor chain similar to 'read_to_at' except reading 'count' or
343 /// returning an error if 'count' bytes can't be read.
read_exact_to_at<F: FileReadWriteAtVolatile>( &mut self, dst: &F, mut count: usize, mut off: u64, ) -> io::Result<()>344 pub fn read_exact_to_at<F: FileReadWriteAtVolatile>(
345 &mut self,
346 dst: &F,
347 mut count: usize,
348 mut off: u64,
349 ) -> io::Result<()> {
350 while count > 0 {
351 match self.read_to_at(dst, count, off) {
352 Ok(0) => {
353 return Err(io::Error::new(
354 io::ErrorKind::UnexpectedEof,
355 "failed to fill whole buffer",
356 ))
357 }
358 Ok(n) => {
359 count -= n;
360 off += n as u64;
361 }
362 Err(ref e) if e.kind() == io::ErrorKind::Interrupted => {}
363 Err(e) => return Err(e),
364 }
365 }
366
367 Ok(())
368 }
369
370 /// Reads data from the descriptor chain buffer into an `AsyncDisk` at offset `off`.
371 /// Returns the number of bytes read from the descriptor chain buffer.
372 /// The number of bytes read can be less than `count` if there isn't
373 /// enough data in the descriptor chain buffer.
read_to_at_fut<F: AsyncDisk + ?Sized>( &mut self, dst: &F, count: usize, off: u64, ) -> disk::Result<usize>374 pub async fn read_to_at_fut<F: AsyncDisk + ?Sized>(
375 &mut self,
376 dst: &F,
377 count: usize,
378 off: u64,
379 ) -> disk::Result<usize> {
380 let written = dst
381 .write_from_mem(
382 off,
383 Arc::new(self.mem.clone()),
384 self.regions.get_remaining_regions_with_count(count),
385 )
386 .await?;
387 self.regions.consume(written);
388 Ok(written)
389 }
390
391 /// Reads exactly `count` bytes from the chain to the disk asynchronously or returns an error if
392 /// not enough data can be read.
read_exact_to_at_fut<F: AsyncDisk + ?Sized>( &mut self, dst: &F, mut count: usize, mut off: u64, ) -> disk::Result<()>393 pub async fn read_exact_to_at_fut<F: AsyncDisk + ?Sized>(
394 &mut self,
395 dst: &F,
396 mut count: usize,
397 mut off: u64,
398 ) -> disk::Result<()> {
399 while count > 0 {
400 let nread = self.read_to_at_fut(dst, count, off).await?;
401 if nread == 0 {
402 return Err(disk::Error::ReadingData(io::Error::new(
403 io::ErrorKind::UnexpectedEof,
404 "failed to write whole buffer",
405 )));
406 }
407 count -= nread;
408 off += nread as u64;
409 }
410
411 Ok(())
412 }
413
414 /// Returns number of bytes available for reading. May return an error if the combined
415 /// lengths of all the buffers in the DescriptorChain would cause an integer overflow.
available_bytes(&self) -> usize416 pub fn available_bytes(&self) -> usize {
417 self.regions.available_bytes()
418 }
419
420 /// Returns number of bytes already read from the descriptor chain buffer.
bytes_read(&self) -> usize421 pub fn bytes_read(&self) -> usize {
422 self.regions.bytes_consumed()
423 }
424
get_remaining_regions(&self) -> MemRegionIter425 pub fn get_remaining_regions(&self) -> MemRegionIter {
426 self.regions.get_remaining_regions()
427 }
428
429 /// Returns a `&[VolatileSlice]` that represents all the remaining data in this `Reader`.
430 /// Calling this method does not actually consume any data from the `Reader` and callers should
431 /// call `consume` to advance the `Reader`.
get_remaining(&self) -> SmallVec<[VolatileSlice; 16]>432 pub fn get_remaining(&self) -> SmallVec<[VolatileSlice; 16]> {
433 self.regions.get_remaining(&self.mem)
434 }
435
436 /// Consumes `amt` bytes from the underlying descriptor chain. If `amt` is larger than the
437 /// remaining data left in this `Reader`, then all remaining data will be consumed.
consume(&mut self, amt: usize)438 pub fn consume(&mut self, amt: usize) {
439 self.regions.consume(amt)
440 }
441
442 /// Splits this `Reader` into two at the given offset in the `DescriptorChain` buffer. After the
443 /// split, `self` will be able to read up to `offset` bytes while the returned `Reader` can read
444 /// up to `available_bytes() - offset` bytes. If `offset > self.available_bytes()`, then the
445 /// returned `Reader` will not be able to read any bytes.
split_at(&mut self, offset: usize) -> Reader446 pub fn split_at(&mut self, offset: usize) -> Reader {
447 Reader {
448 mem: self.mem.clone(),
449 regions: self.regions.split_at(offset),
450 }
451 }
452 }
453
454 /// Copy up to `size` bytes from `src` into `dst`.
455 ///
456 /// Returns the total number of bytes copied.
457 ///
458 /// # Safety
459 ///
460 /// The caller must ensure that it is safe to write `size` bytes of data into `dst`.
461 ///
462 /// After the function returns, it is only safe to assume that the number of bytes indicated by the
463 /// return value (which may be less than the requested `size`) have been initialized. Bytes beyond
464 /// that point are not initialized by this function.
copy_regions_to_mut_ptr( mem: &GuestMemory, src: MemRegionIter, dst: *mut u8, size: usize, ) -> io::Result<usize>465 unsafe fn copy_regions_to_mut_ptr(
466 mem: &GuestMemory,
467 src: MemRegionIter,
468 dst: *mut u8,
469 size: usize,
470 ) -> io::Result<usize> {
471 let mut copied = 0;
472 for src_region in src {
473 if copied >= size {
474 break;
475 }
476
477 let remaining = size - copied;
478 let count = cmp::min(remaining, src_region.len);
479
480 let vslice = mem
481 .get_slice_at_addr(GuestAddress(src_region.offset), count)
482 .map_err(|_e| io::Error::from(io::ErrorKind::InvalidData))?;
483
484 // SAFETY: `get_slice_at_addr()` verified that the region points to valid memory, and
485 // the `count` calculation ensures we will write at most `size` bytes into `dst`.
486 unsafe {
487 copy_nonoverlapping(vslice.as_ptr(), dst.add(copied), count);
488 }
489
490 copied += count;
491 }
492
493 Ok(copied)
494 }
495
496 impl io::Read for Reader {
read(&mut self, buf: &mut [u8]) -> io::Result<usize>497 fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
498 // SAFETY: We pass a valid pointer and size combination derived from `buf`.
499 let total = unsafe {
500 copy_regions_to_mut_ptr(
501 &self.mem,
502 self.regions.get_remaining_regions(),
503 buf.as_mut_ptr(),
504 buf.len(),
505 )?
506 };
507 self.regions.consume(total);
508 Ok(total)
509 }
510 }
511
/// Provides a high-level interface over the sequence of memory regions
/// defined by writable descriptors in the descriptor chain.
///
/// Note that virtio spec requires driver to place any device-writable
/// descriptors after any device-readable descriptors (2.6.4.2 in Virtio Spec v1.1).
/// Writer will start iterating the descriptors from the first writable one and will
/// assume that all following descriptors are writable.
pub struct Writer {
    // Guest memory against which the regions' offsets are resolved.
    mem: GuestMemory,
    // The writable regions together with the current write position within them.
    regions: DescriptorChainRegions,
}
523
524 impl Writer {
525 /// Construct a new Writer wrapper over `writable_regions`.
new_from_regions( mem: &GuestMemory, writable_regions: SmallVec<[MemRegion; 2]>, ) -> Writer526 pub fn new_from_regions(
527 mem: &GuestMemory,
528 writable_regions: SmallVec<[MemRegion; 2]>,
529 ) -> Writer {
530 Writer {
531 mem: mem.clone(),
532 regions: DescriptorChainRegions::new(writable_regions),
533 }
534 }
535
536 /// Writes an object to the descriptor chain buffer.
write_obj<T: Immutable + IntoBytes>(&mut self, val: T) -> io::Result<()>537 pub fn write_obj<T: Immutable + IntoBytes>(&mut self, val: T) -> io::Result<()> {
538 self.write_all(val.as_bytes())
539 }
540
541 /// Writes all objects produced by `iter` into the descriptor chain buffer. Unlike `consume`,
542 /// this doesn't require the values to be stored in an intermediate collection first. It also
543 /// allows callers to choose which elements in a collection to write, for example by using the
544 /// `filter` or `take` methods of the `Iterator` trait.
write_iter<T: Immutable + IntoBytes, I: Iterator<Item = T>>( &mut self, mut iter: I, ) -> io::Result<()>545 pub fn write_iter<T: Immutable + IntoBytes, I: Iterator<Item = T>>(
546 &mut self,
547 mut iter: I,
548 ) -> io::Result<()> {
549 iter.try_for_each(|v| self.write_obj(v))
550 }
551
552 /// Writes a collection of objects into the descriptor chain buffer.
consume<T: Immutable + IntoBytes, C: IntoIterator<Item = T>>( &mut self, vals: C, ) -> io::Result<()>553 pub fn consume<T: Immutable + IntoBytes, C: IntoIterator<Item = T>>(
554 &mut self,
555 vals: C,
556 ) -> io::Result<()> {
557 self.write_iter(vals.into_iter())
558 }
559
560 /// Returns number of bytes available for writing. May return an error if the combined
561 /// lengths of all the buffers in the DescriptorChain would cause an overflow.
available_bytes(&self) -> usize562 pub fn available_bytes(&self) -> usize {
563 self.regions.available_bytes()
564 }
565
566 /// Reads data into a volatile slice up to the minimum of the slice's length or the number of
567 /// bytes remaining. Returns the number of bytes read.
write_from_volatile_slice(&mut self, slice: VolatileSlice) -> usize568 pub fn write_from_volatile_slice(&mut self, slice: VolatileSlice) -> usize {
569 let mut written = 0usize;
570 let mut src = slice;
571 for dst in self.get_remaining() {
572 src.copy_to_volatile_slice(dst);
573 let copied = std::cmp::min(src.size(), dst.size());
574 written += copied;
575 src = match src.offset(copied) {
576 Ok(v) => v,
577 Err(_) => break, // The slice is fully consumed
578 };
579 }
580 self.regions.consume(written);
581 written
582 }
583
584 /// Writes data to the descriptor chain buffer from a readable object.
585 /// Returns the number of bytes written to the descriptor chain buffer.
586 /// The number of bytes written can be less than `count` if
587 /// there isn't enough data in the descriptor chain buffer.
write_from<F: FileReadWriteVolatile>( &mut self, mut src: F, count: usize, ) -> io::Result<usize>588 pub fn write_from<F: FileReadWriteVolatile>(
589 &mut self,
590 mut src: F,
591 count: usize,
592 ) -> io::Result<usize> {
593 let iovs = self.regions.get_remaining_with_count(&self.mem, count);
594 let read = src.read_vectored_volatile(&iovs[..])?;
595 self.regions.consume(read);
596 Ok(read)
597 }
598
599 /// Writes data to the descriptor chain buffer from a File at offset `off`.
600 /// Returns the number of bytes written to the descriptor chain buffer.
601 /// The number of bytes written can be less than `count` if
602 /// there isn't enough data in the descriptor chain buffer.
write_from_at<F: FileReadWriteAtVolatile>( &mut self, src: &F, count: usize, off: u64, ) -> io::Result<usize>603 pub fn write_from_at<F: FileReadWriteAtVolatile>(
604 &mut self,
605 src: &F,
606 count: usize,
607 off: u64,
608 ) -> io::Result<usize> {
609 let iovs = self.regions.get_remaining_with_count(&self.mem, count);
610 let read = src.read_vectored_at_volatile(&iovs[..], off)?;
611 self.regions.consume(read);
612 Ok(read)
613 }
614
write_all_from<F: FileReadWriteVolatile>( &mut self, mut src: F, mut count: usize, ) -> io::Result<()>615 pub fn write_all_from<F: FileReadWriteVolatile>(
616 &mut self,
617 mut src: F,
618 mut count: usize,
619 ) -> io::Result<()> {
620 while count > 0 {
621 match self.write_from(&mut src, count) {
622 Ok(0) => {
623 return Err(io::Error::new(
624 io::ErrorKind::WriteZero,
625 "failed to write whole buffer",
626 ))
627 }
628 Ok(n) => count -= n,
629 Err(ref e) if e.kind() == io::ErrorKind::Interrupted => {}
630 Err(e) => return Err(e),
631 }
632 }
633
634 Ok(())
635 }
636
write_all_from_at<F: FileReadWriteAtVolatile>( &mut self, src: &F, mut count: usize, mut off: u64, ) -> io::Result<()>637 pub fn write_all_from_at<F: FileReadWriteAtVolatile>(
638 &mut self,
639 src: &F,
640 mut count: usize,
641 mut off: u64,
642 ) -> io::Result<()> {
643 while count > 0 {
644 match self.write_from_at(src, count, off) {
645 Ok(0) => {
646 return Err(io::Error::new(
647 io::ErrorKind::WriteZero,
648 "failed to write whole buffer",
649 ))
650 }
651 Ok(n) => {
652 count -= n;
653 off += n as u64;
654 }
655 Err(ref e) if e.kind() == io::ErrorKind::Interrupted => {}
656 Err(e) => return Err(e),
657 }
658 }
659 Ok(())
660 }
661 /// Writes data to the descriptor chain buffer from an `AsyncDisk` at offset `off`.
662 /// Returns the number of bytes written to the descriptor chain buffer.
663 /// The number of bytes written can be less than `count` if
664 /// there isn't enough data in the descriptor chain buffer.
write_from_at_fut<F: AsyncDisk + ?Sized>( &mut self, src: &F, count: usize, off: u64, ) -> disk::Result<usize>665 pub async fn write_from_at_fut<F: AsyncDisk + ?Sized>(
666 &mut self,
667 src: &F,
668 count: usize,
669 off: u64,
670 ) -> disk::Result<usize> {
671 let read = src
672 .read_to_mem(
673 off,
674 Arc::new(self.mem.clone()),
675 self.regions.get_remaining_regions_with_count(count),
676 )
677 .await?;
678 self.regions.consume(read);
679 Ok(read)
680 }
681
write_all_from_at_fut<F: AsyncDisk + ?Sized>( &mut self, src: &F, mut count: usize, mut off: u64, ) -> disk::Result<()>682 pub async fn write_all_from_at_fut<F: AsyncDisk + ?Sized>(
683 &mut self,
684 src: &F,
685 mut count: usize,
686 mut off: u64,
687 ) -> disk::Result<()> {
688 while count > 0 {
689 let nwritten = self.write_from_at_fut(src, count, off).await?;
690 if nwritten == 0 {
691 return Err(disk::Error::WritingData(io::Error::new(
692 io::ErrorKind::UnexpectedEof,
693 "failed to write whole buffer",
694 )));
695 }
696 count -= nwritten;
697 off += nwritten as u64;
698 }
699 Ok(())
700 }
701
702 /// Returns number of bytes already written to the descriptor chain buffer.
bytes_written(&self) -> usize703 pub fn bytes_written(&self) -> usize {
704 self.regions.bytes_consumed()
705 }
706
get_remaining_regions(&self) -> MemRegionIter707 pub fn get_remaining_regions(&self) -> MemRegionIter {
708 self.regions.get_remaining_regions()
709 }
710
711 /// Returns a `&[VolatileSlice]` that represents all the remaining data in this `Writer`.
712 /// Calling this method does not actually advance the current position of the `Writer` in the
713 /// buffer and callers should call `consume_bytes` to advance the `Writer`. Not calling
714 /// `consume_bytes` with the amount of data copied into the returned `VolatileSlice`s will
715 /// result in that that data being overwritten the next time data is written into the `Writer`.
get_remaining(&self) -> SmallVec<[VolatileSlice; 16]>716 pub fn get_remaining(&self) -> SmallVec<[VolatileSlice; 16]> {
717 self.regions.get_remaining(&self.mem)
718 }
719
720 /// Consumes `amt` bytes from the underlying descriptor chain. If `amt` is larger than the
721 /// remaining data left in this `Reader`, then all remaining data will be consumed.
consume_bytes(&mut self, amt: usize)722 pub fn consume_bytes(&mut self, amt: usize) {
723 self.regions.consume(amt)
724 }
725
726 /// Splits this `Writer` into two at the given offset in the `DescriptorChain` buffer. After the
727 /// split, `self` will be able to write up to `offset` bytes while the returned `Writer` can
728 /// write up to `available_bytes() - offset` bytes. If `offset > self.available_bytes()`, then
729 /// the returned `Writer` will not be able to write any bytes.
split_at(&mut self, offset: usize) -> Writer730 pub fn split_at(&mut self, offset: usize) -> Writer {
731 Writer {
732 mem: self.mem.clone(),
733 regions: self.regions.split_at(offset),
734 }
735 }
736 }
737
738 impl io::Write for Writer {
write(&mut self, buf: &[u8]) -> io::Result<usize>739 fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
740 let mut rem = buf;
741 let mut total = 0;
742 for b in self.regions.get_remaining(&self.mem) {
743 if rem.is_empty() {
744 break;
745 }
746
747 let count = cmp::min(rem.len(), b.size());
748 // SAFETY:
749 // Safe because we have already verified that `vs` points to valid memory.
750 unsafe {
751 copy_nonoverlapping(rem.as_ptr(), b.as_mut_ptr(), count);
752 }
753 rem = &rem[count..];
754 total += count;
755 }
756
757 self.regions.consume(total);
758 Ok(total)
759 }
760
flush(&mut self) -> io::Result<()>761 fn flush(&mut self) -> io::Result<()> {
762 // Nothing to flush since the writes go straight into the buffer.
763 Ok(())
764 }
765 }
766
/// Descriptor flag: the buffer continues in the descriptor referenced by the `next` field.
const VIRTQ_DESC_F_NEXT: u16 = 0x1;
/// Descriptor flag: the buffer is device-writable (otherwise it is device-readable).
const VIRTQ_DESC_F_WRITE: u16 = 0x2;
769
/// Direction of a descriptor from the device's point of view.
///
/// `create_descriptor_chain` sets `VIRTQ_DESC_F_WRITE` for `Writable` descriptors.
#[derive(Copy, Clone, Debug, PartialEq, Eq)]
pub enum DescriptorType {
    /// The device reads from this descriptor's buffer.
    Readable,
    /// The device writes into this descriptor's buffer.
    Writable,
}
775
/// In-memory layout of a split virtqueue descriptor table entry, as written into guest memory
/// by `create_descriptor_chain`. Named after the virtio spec's `struct virtq_desc`.
#[derive(Copy, Clone, Debug, FromBytes, Immutable, IntoBytes, KnownLayout)]
#[repr(C)]
struct virtq_desc {
    // Guest address of the descriptor's buffer.
    addr: Le64,
    // Length of the buffer in bytes.
    len: Le32,
    // Combination of `VIRTQ_DESC_F_*` flags.
    flags: Le16,
    // Index of the next descriptor; meaningful only when `VIRTQ_DESC_F_NEXT` is set.
    next: Le16,
}
784
785 /// Test utility function to create a descriptor chain in guest memory.
create_descriptor_chain( memory: &GuestMemory, descriptor_array_addr: GuestAddress, mut buffers_start_addr: GuestAddress, descriptors: Vec<(DescriptorType, u32)>, spaces_between_regions: u32, ) -> anyhow::Result<DescriptorChain>786 pub fn create_descriptor_chain(
787 memory: &GuestMemory,
788 descriptor_array_addr: GuestAddress,
789 mut buffers_start_addr: GuestAddress,
790 descriptors: Vec<(DescriptorType, u32)>,
791 spaces_between_regions: u32,
792 ) -> anyhow::Result<DescriptorChain> {
793 let descriptors_len = descriptors.len();
794 for (index, (type_, size)) in descriptors.into_iter().enumerate() {
795 let mut flags = 0;
796 if let DescriptorType::Writable = type_ {
797 flags |= VIRTQ_DESC_F_WRITE;
798 }
799 if index + 1 < descriptors_len {
800 flags |= VIRTQ_DESC_F_NEXT;
801 }
802
803 let index = index as u16;
804 let desc = virtq_desc {
805 addr: buffers_start_addr.offset().into(),
806 len: size.into(),
807 flags: flags.into(),
808 next: (index + 1).into(),
809 };
810
811 let offset = size + spaces_between_regions;
812 buffers_start_addr = buffers_start_addr
813 .checked_add(offset as u64)
814 .context("Invalid buffers_start_addr)")?;
815
816 let _ = memory.write_obj_at_addr(
817 desc,
818 descriptor_array_addr
819 .checked_add(index as u64 * std::mem::size_of::<virtq_desc>() as u64)
820 .context("Invalid descriptor_array_addr")?,
821 );
822 }
823
824 let chain = SplitDescriptorChain::new(memory, descriptor_array_addr, 0x100, 0);
825 DescriptorChain::new(chain, memory, 0)
826 }
827
#[cfg(test)]
mod tests {
    use std::fs::File;
    use std::io::Read;

    use cros_async::Executor;
    use tempfile::tempfile;
    use tempfile::NamedTempFile;

    use super::*;

    /// Creates the guest memory shared by every test: a single 0x10000-byte region starting at
    /// guest address 0.
    fn test_memory() -> GuestMemory {
        GuestMemory::new(&[(GuestAddress(0x0), 0x10000)]).unwrap()
    }

    /// Builds a chain of three readable descriptors (16 + 16 + 96 = 128 bytes) followed by three
    /// writable descriptors (64 + 1 + 3 = 68 bytes). The descriptor table lives at guest address
    /// 0x0 and the buffers are laid out back to back starting at 0x100.
    fn mixed_chain(memory: &GuestMemory) -> DescriptorChain {
        use DescriptorType::*;

        create_descriptor_chain(
            memory,
            GuestAddress(0x0),
            GuestAddress(0x100),
            vec![
                (Readable, 16),
                (Readable, 16),
                (Readable, 96),
                (Writable, 64),
                (Writable, 1),
                (Writable, 3),
            ],
            0,
        )
        .expect("create_descriptor_chain failed")
    }

    /// Reading consumes bytes across descriptor boundaries and stops at the end of the chain.
    #[test]
    fn reader_test_simple_chain() {
        use DescriptorType::*;

        let memory = test_memory();

        let mut chain = create_descriptor_chain(
            &memory,
            GuestAddress(0x0),
            GuestAddress(0x100),
            vec![
                (Readable, 8),
                (Readable, 16),
                (Readable, 18),
                (Readable, 64),
            ],
            0,
        )
        .expect("create_descriptor_chain failed");
        let reader = &mut chain.reader;
        assert_eq!(reader.available_bytes(), 106);
        assert_eq!(reader.bytes_read(), 0);

        let mut buffer = [0u8; 64];
        reader
            .read_exact(&mut buffer)
            .expect("read_exact should not fail here");

        assert_eq!(reader.available_bytes(), 42);
        assert_eq!(reader.bytes_read(), 64);

        // Only 42 bytes remain, so a 64-byte read returns a short count.
        match reader.read(&mut buffer) {
            Err(_) => panic!("read should not fail here"),
            Ok(length) => assert_eq!(length, 42),
        }

        assert_eq!(reader.available_bytes(), 0);
        assert_eq!(reader.bytes_read(), 106);
    }

    /// Writing fills descriptors across boundaries and stops at the end of the chain.
    #[test]
    fn writer_test_simple_chain() {
        use DescriptorType::*;

        let memory = test_memory();

        let mut chain = create_descriptor_chain(
            &memory,
            GuestAddress(0x0),
            GuestAddress(0x100),
            vec![
                (Writable, 8),
                (Writable, 16),
                (Writable, 18),
                (Writable, 64),
            ],
            0,
        )
        .expect("create_descriptor_chain failed");
        let writer = &mut chain.writer;
        assert_eq!(writer.available_bytes(), 106);
        assert_eq!(writer.bytes_written(), 0);

        let buffer = [0; 64];
        writer
            .write_all(&buffer)
            .expect("write_all should not fail here");

        assert_eq!(writer.available_bytes(), 42);
        assert_eq!(writer.bytes_written(), 64);

        // Only 42 bytes of space remain, so a 64-byte write returns a short count.
        match writer.write(&buffer) {
            Err(_) => panic!("write should not fail here"),
            Ok(length) => assert_eq!(length, 42),
        }

        assert_eq!(writer.available_bytes(), 0);
        assert_eq!(writer.bytes_written(), 106);
    }

    /// A chain with only writable descriptors offers nothing to read.
    #[test]
    fn reader_test_incompatible_chain() {
        use DescriptorType::*;

        let memory = test_memory();

        let mut chain = create_descriptor_chain(
            &memory,
            GuestAddress(0x0),
            GuestAddress(0x100),
            vec![(Writable, 8)],
            0,
        )
        .expect("create_descriptor_chain failed");
        let reader = &mut chain.reader;
        assert_eq!(reader.available_bytes(), 0);
        assert_eq!(reader.bytes_read(), 0);

        assert!(reader.read_obj::<u8>().is_err());

        assert_eq!(reader.available_bytes(), 0);
        assert_eq!(reader.bytes_read(), 0);
    }

    /// A chain with only readable descriptors offers no space to write.
    #[test]
    fn writer_test_incompatible_chain() {
        use DescriptorType::*;

        let memory = test_memory();

        let mut chain = create_descriptor_chain(
            &memory,
            GuestAddress(0x0),
            GuestAddress(0x100),
            vec![(Readable, 8)],
            0,
        )
        .expect("create_descriptor_chain failed");
        let writer = &mut chain.writer;
        assert_eq!(writer.available_bytes(), 0);
        assert_eq!(writer.bytes_written(), 0);

        assert!(writer.write_obj(0u8).is_err());

        assert_eq!(writer.available_bytes(), 0);
        assert_eq!(writer.bytes_written(), 0);
    }

    /// A failed write to the destination file must not consume any bytes from the reader.
    #[test]
    fn reader_failing_io() {
        use DescriptorType::*;

        let memory = test_memory();

        let mut chain = create_descriptor_chain(
            &memory,
            GuestAddress(0x0),
            GuestAddress(0x100),
            vec![(Readable, 256), (Readable, 256)],
            0,
        )
        .expect("create_descriptor_chain failed");

        let reader = &mut chain.reader;

        // Open a device file read-only so that writing the chain's contents to it fails with an
        // I/O error.
        let device_file = if cfg!(windows) { "NUL" } else { "/dev/zero" };
        let mut ro_file = File::open(device_file).expect("failed to open device file");

        reader
            .read_exact_to(&mut ro_file, 512)
            .expect_err("read_exact_to into a read-only file unexpectedly succeeded");

        // The write above should have failed entirely, so no bytes were consumed from the chain.
        assert_eq!(reader.available_bytes(), 512);
        assert_eq!(reader.bytes_read(), 0);
    }

    /// Filling the chain from a file that is too short fails, but keeps the bytes already copied.
    #[test]
    fn writer_failing_io() {
        use DescriptorType::*;

        let memory = test_memory();

        let mut chain = create_descriptor_chain(
            &memory,
            GuestAddress(0x0),
            GuestAddress(0x100),
            vec![(Writable, 256), (Writable, 256)],
            0,
        )
        .expect("create_descriptor_chain failed");

        let writer = &mut chain.writer;

        // The file holds only 384 bytes, so reading 512 bytes from it runs past its end.
        let mut file = tempfile().unwrap();
        file.set_len(384).unwrap();

        writer
            .write_all_from(&mut file, 512)
            .expect_err("write_all_from read past the end of a 384-byte file");

        // The 384 bytes the file did contain were written before the error.
        assert_eq!(writer.available_bytes(), 128);
        assert_eq!(writer.bytes_written(), 384);
    }

    /// The reader and writer halves of one chain operate independently.
    #[test]
    fn reader_writer_shared_chain() {
        let memory = test_memory();

        let mut chain = mixed_chain(&memory);
        let reader = &mut chain.reader;
        let writer = &mut chain.writer;

        assert_eq!(reader.bytes_read(), 0);
        assert_eq!(writer.bytes_written(), 0);

        let mut buffer = Vec::with_capacity(200);

        assert_eq!(
            reader
                .read_to_end(&mut buffer)
                .expect("read should not fail here"),
            128
        );

        // The writable descriptors are only 68 bytes long.
        writer
            .write_all(&buffer[..68])
            .expect("write should not fail here");

        assert_eq!(reader.available_bytes(), 0);
        assert_eq!(reader.bytes_read(), 128);
        assert_eq!(writer.available_bytes(), 0);
        assert_eq!(writer.bytes_written(), 68);
    }

    /// An object split across several non-adjacent 1-byte descriptors round-trips intact.
    #[test]
    fn reader_writer_shattered_object() {
        use DescriptorType::*;

        let memory = test_memory();

        let secret: Le32 = 0x12345678.into();

        // Create a descriptor chain whose memory regions are separated by gaps.
        let mut chain_writer = create_descriptor_chain(
            &memory,
            GuestAddress(0x0),
            GuestAddress(0x100),
            vec![(Writable, 1), (Writable, 1), (Writable, 1), (Writable, 1)],
            123,
        )
        .expect("create_descriptor_chain failed");
        let writer = &mut chain_writer.writer;
        writer
            .write_obj(secret)
            .expect("write_obj should not fail here");

        // Now create a new descriptor chain pointing to the same memory and read it back.
        let mut chain_reader = create_descriptor_chain(
            &memory,
            GuestAddress(0x0),
            GuestAddress(0x100),
            vec![(Readable, 1), (Readable, 1), (Readable, 1), (Readable, 1)],
            123,
        )
        .expect("create_descriptor_chain failed");
        let reader = &mut chain_reader.reader;
        match reader.read_obj::<Le32>() {
            Err(_) => panic!("read_obj should not fail here"),
            Ok(read_secret) => assert_eq!(read_secret, secret),
        }
    }

    /// `read_exact` past the end of the chain reports `UnexpectedEof`.
    #[test]
    fn reader_unexpected_eof() {
        use DescriptorType::*;

        let memory = test_memory();

        let mut chain = create_descriptor_chain(
            &memory,
            GuestAddress(0x0),
            GuestAddress(0x100),
            vec![(Readable, 256), (Readable, 256)],
            0,
        )
        .expect("create_descriptor_chain failed");

        let reader = &mut chain.reader;

        let mut buf = vec![0; 1024];

        assert_eq!(
            reader
                .read_exact(&mut buf[..])
                .expect_err("read more bytes than available")
                .kind(),
            io::ErrorKind::UnexpectedEof
        );
    }

    /// Splitting exactly on a descriptor boundary.
    #[test]
    fn split_border() {
        let memory = test_memory();

        let mut chain = mixed_chain(&memory);
        let reader = &mut chain.reader;

        let other = reader.split_at(32);
        assert_eq!(reader.available_bytes(), 32);
        assert_eq!(other.available_bytes(), 96);
    }

    /// Splitting in the middle of a descriptor.
    #[test]
    fn split_middle() {
        let memory = test_memory();

        let mut chain = mixed_chain(&memory);
        let reader = &mut chain.reader;

        let other = reader.split_at(24);
        assert_eq!(reader.available_bytes(), 24);
        assert_eq!(other.available_bytes(), 104);
    }

    /// Splitting at the very end leaves nothing for the returned half.
    #[test]
    fn split_end() {
        let memory = test_memory();

        let mut chain = mixed_chain(&memory);
        let reader = &mut chain.reader;

        let other = reader.split_at(128);
        assert_eq!(reader.available_bytes(), 128);
        assert_eq!(other.available_bytes(), 0);
    }

    /// Splitting at offset 0 moves everything into the returned half.
    #[test]
    fn split_beginning() {
        let memory = test_memory();

        let mut chain = mixed_chain(&memory);
        let reader = &mut chain.reader;

        let other = reader.split_at(0);
        assert_eq!(reader.available_bytes(), 0);
        assert_eq!(other.available_bytes(), 128);
    }

    /// Splitting past the end yields an empty second half.
    #[test]
    fn split_outofbounds() {
        let memory = test_memory();

        let mut chain = mixed_chain(&memory);
        let reader = &mut chain.reader;

        let other = reader.split_at(256);
        assert_eq!(
            other.available_bytes(),
            0,
            "Reader returned from out-of-bounds split still has available bytes"
        );
    }

    /// A read larger than the chain returns only the bytes the chain holds.
    #[test]
    fn read_full() {
        use DescriptorType::*;

        let memory = test_memory();

        let mut chain = create_descriptor_chain(
            &memory,
            GuestAddress(0x0),
            GuestAddress(0x100),
            vec![(Readable, 16), (Readable, 16), (Readable, 16)],
            0,
        )
        .expect("create_descriptor_chain failed");
        let reader = &mut chain.reader;

        let mut buf = [0u8; 64];
        assert_eq!(
            reader.read(&mut buf[..]).expect("failed to read to buffer"),
            48
        );
    }

    /// A write larger than the chain stores only as many bytes as the chain can hold.
    #[test]
    fn write_full() {
        use DescriptorType::*;

        let memory = test_memory();

        let mut chain = create_descriptor_chain(
            &memory,
            GuestAddress(0x0),
            GuestAddress(0x100),
            vec![(Writable, 16), (Writable, 16), (Writable, 16)],
            0,
        )
        .expect("create_descriptor_chain failed");
        let writer = &mut chain.writer;

        let buf = [0xdeu8; 64];
        assert_eq!(
            writer.write(&buf[..]).expect("failed to write from buffer"),
            48
        );
    }

    /// Values written with `consume()` are read back unchanged by `collect()`.
    #[test]
    fn consume_collect() {
        use DescriptorType::*;

        let memory = test_memory();
        let vs: Vec<Le64> = vec![
            0x0101010101010101.into(),
            0x0202020202020202.into(),
            0x0303030303030303.into(),
        ];

        let mut write_chain = create_descriptor_chain(
            &memory,
            GuestAddress(0x0),
            GuestAddress(0x100),
            vec![(Writable, 24)],
            0,
        )
        .expect("create_descriptor_chain failed");
        let writer = &mut write_chain.writer;
        writer
            .consume(vs.clone())
            .expect("failed to consume() a vector");

        let mut read_chain = create_descriptor_chain(
            &memory,
            GuestAddress(0x0),
            GuestAddress(0x100),
            vec![(Readable, 24)],
            0,
        )
        .expect("create_descriptor_chain failed");
        let reader = &mut read_chain.reader;
        let vs_read = reader
            .collect::<io::Result<Vec<Le64>>, _>()
            .expect("failed to collect() values");
        assert_eq!(vs, vs_read);
    }

    /// `get_remaining_regions_with_count` never reports more than the requested byte count.
    #[test]
    fn get_remaining_region_with_count() {
        let memory = test_memory();

        let chain = mixed_chain(&memory);

        let Reader {
            mem: _,
            mut regions,
        } = chain.reader;

        let drain = regions
            .get_remaining_regions_with_count(usize::MAX)
            .fold(0usize, |total, region| total + region.len);
        assert_eq!(drain, 128);

        let exact = regions
            .get_remaining_regions_with_count(32)
            .fold(0usize, |total, region| total + region.len);
        assert!(exact > 0);
        assert!(exact <= 32);

        let split = regions
            .get_remaining_regions_with_count(24)
            .fold(0usize, |total, region| total + region.len);
        assert!(split > 0);
        assert!(split <= 24);

        regions.consume(64);

        let first = regions
            .get_remaining_regions_with_count(8)
            .fold(0usize, |total, region| total + region.len);
        assert!(first > 0);
        assert!(first <= 8);
    }

    /// `get_remaining_with_count` never returns slices totalling more than the requested count.
    #[test]
    fn get_remaining_with_count() {
        let memory = test_memory();

        let chain = mixed_chain(&memory);
        let Reader {
            mem: _,
            mut regions,
        } = chain.reader;

        let drain = regions
            .get_remaining_with_count(&memory, usize::MAX)
            .iter()
            .fold(0usize, |total, iov| total + iov.size());
        assert_eq!(drain, 128);

        let exact = regions
            .get_remaining_with_count(&memory, 32)
            .iter()
            .fold(0usize, |total, iov| total + iov.size());
        assert!(exact > 0);
        assert!(exact <= 32);

        let split = regions
            .get_remaining_with_count(&memory, 24)
            .iter()
            .fold(0usize, |total, iov| total + iov.size());
        assert!(split > 0);
        assert!(split <= 24);

        regions.consume(64);

        let first = regions
            .get_remaining_with_count(&memory, 8)
            .iter()
            .fold(0usize, |total, iov| total + iov.size());
        assert!(first > 0);
        assert!(first <= 8);
    }

    /// `peek_obj` copies without consuming, including across descriptor boundaries.
    #[test]
    fn reader_peek_obj() {
        use DescriptorType::*;

        let memory = test_memory();

        // Write test data to memory.
        memory
            .write_obj_at_addr(Le16::from(0xBEEF), GuestAddress(0x100))
            .unwrap();
        memory
            .write_obj_at_addr(Le16::from(0xDEAD), GuestAddress(0x200))
            .unwrap();

        // Two 2-byte descriptors at 0x100 and 0x100 + 2 + (0x100 - 2) = 0x200.
        let mut chain_reader = create_descriptor_chain(
            &memory,
            GuestAddress(0x0),
            GuestAddress(0x100),
            vec![(Readable, 2), (Readable, 2)],
            0x100 - 2,
        )
        .expect("create_descriptor_chain failed");
        let reader = &mut chain_reader.reader;

        // peek_obj() at the beginning of the chain should return the first object.
        let peek1 = reader.peek_obj::<Le16>().unwrap();
        assert_eq!(peek1, Le16::from(0xBEEF));

        // peek_obj() again should return the same object, since it was not consumed.
        let peek2 = reader.peek_obj::<Le16>().unwrap();
        assert_eq!(peek2, Le16::from(0xBEEF));

        // peek_obj() of an object spanning two descriptors should copy from both.
        let peek3 = reader.peek_obj::<Le32>().unwrap();
        assert_eq!(peek3, Le32::from(0xDEADBEEF));

        // read_obj() should return the first object.
        let read1 = reader.read_obj::<Le16>().unwrap();
        assert_eq!(read1, Le16::from(0xBEEF));

        // peek_obj() of a value that is larger than the rest of the chain should fail.
        reader
            .peek_obj::<Le32>()
            .expect_err("peek_obj past end of chain");

        // read_obj() again should return the second object.
        let read2 = reader.read_obj::<Le16>().unwrap();
        assert_eq!(read2, Le16::from(0xDEAD));

        // peek_obj() should fail at the end of the chain.
        reader
            .peek_obj::<Le16>()
            .expect_err("peek_obj past end of chain");
    }

    /// Async counterpart of `reader_failing_io`.
    #[test]
    fn region_reader_failing_io() {
        let ex = Executor::new().unwrap();
        ex.run_until(region_reader_failing_io_async(&ex)).unwrap();
    }
    async fn region_reader_failing_io_async(ex: &Executor) {
        use DescriptorType::*;

        let memory = test_memory();

        let mut chain = create_descriptor_chain(
            &memory,
            GuestAddress(0x0),
            GuestAddress(0x100),
            vec![(Readable, 256), (Readable, 256)],
            0,
        )
        .expect("create_descriptor_chain failed");

        let reader = &mut chain.reader;

        // Open a file read-only so that writing the chain's contents to it fails with an I/O
        // error.
        let named_temp_file = NamedTempFile::new().expect("failed to create temp file");
        let ro_file =
            File::open(named_temp_file.path()).expect("failed to open temp file read only");
        let async_ro_file = disk::SingleFileDisk::new(ro_file, ex).expect("Failed to create SFD");

        reader
            .read_exact_to_at_fut(&async_ro_file, 512, 0)
            .await
            .expect_err("read_exact_to_at_fut into a read-only file unexpectedly succeeded");

        // The write above should have failed entirely, so no bytes were consumed from the chain.
        assert_eq!(reader.available_bytes(), 512);
        assert_eq!(reader.bytes_read(), 0);
    }

    /// Async counterpart of `writer_failing_io`.
    #[test]
    fn region_writer_failing_io() {
        let ex = Executor::new().unwrap();
        ex.run_until(region_writer_failing_io_async(&ex)).unwrap();
    }
    async fn region_writer_failing_io_async(ex: &Executor) {
        use DescriptorType::*;

        let memory = test_memory();

        let mut chain = create_descriptor_chain(
            &memory,
            GuestAddress(0x0),
            GuestAddress(0x100),
            vec![(Writable, 256), (Writable, 256)],
            0,
        )
        .expect("create_descriptor_chain failed");

        let writer = &mut chain.writer;

        // The file holds only 384 bytes, so reading 512 bytes from it runs past its end.
        let file = tempfile().expect("failed to create temp file");
        file.set_len(384).unwrap();
        let async_file = disk::SingleFileDisk::new(file, ex).expect("Failed to create SFD");

        writer
            .write_all_from_at_fut(&async_file, 512, 0)
            .await
            .expect_err("write_all_from_at_fut read past the end of a 384-byte file");

        // The 384 bytes the file did contain were written before the error.
        assert_eq!(writer.available_bytes(), 128);
        assert_eq!(writer.bytes_written(), 384);
    }
}
1625