1 // Copyright 2023 The Pigweed Authors
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License"); you may not
4 // use this file except in compliance with the License. You may obtain a copy of
5 // the License at
6 //
7 // https://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
11 // WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
12 // License for the specific language governing permissions and limitations under
13 // the License.
14
15 use core::cmp::min;
16
17 use paste::paste;
18 use pw_status::{Error, Result};
19 use pw_varint::{VarintDecode, VarintEncode};
20
21 use super::{Read, Seek, SeekFrom, Write};
22
/// Wraps an <code>[AsRef]<[u8]></code> in a container implementing
/// [`Read`], [`Write`], and [`Seek`].
///
/// [`Write`] support requires the inner type also implement
/// <code>[AsMut]<[u8]></code>.
///
/// `Debug` and `Clone` are available whenever the wrapped type provides them.
#[derive(Debug, Clone)]
pub struct Cursor<T>
where
    T: AsRef<[u8]>,
{
    /// The wrapped byte storage.
    inner: T,
    /// Byte offset of the next read or write within `inner`.
    pos: usize,
}
35
36 impl<T: AsRef<[u8]>> Cursor<T> {
37 /// Create a new Cursor wrapping `inner` with an initial position of 0.
38 ///
39 /// Semantics match [`std::io::Cursor::new()`].
new(inner: T) -> Self40 pub fn new(inner: T) -> Self {
41 Self { inner, pos: 0 }
42 }
43
44 /// Consumes the cursor and returns the inner wrapped data.
into_inner(self) -> T45 pub fn into_inner(self) -> T {
46 self.inner
47 }
48
49 /// Returns the number of remaining bytes in the Cursor.
remaining(&self) -> usize50 pub fn remaining(&self) -> usize {
51 self.len() - self.pos
52 }
53
54 /// Returns the total length of the Cursor.
len(&self) -> usize55 pub fn len(&self) -> usize {
56 self.inner.as_ref().len()
57 }
58
59 /// Returns current IO position of the Cursor.
position(&self) -> usize60 pub fn position(&self) -> usize {
61 self.pos
62 }
63
remaining_slice(&mut self) -> &[u8]64 fn remaining_slice(&mut self) -> &[u8] {
65 &self.inner.as_ref()[self.pos..]
66 }
67 }
68
69 impl<T: AsRef<[u8]> + AsMut<[u8]>> Cursor<T> {
remaining_mut(&mut self) -> &mut [u8]70 fn remaining_mut(&mut self) -> &mut [u8] {
71 &mut self.inner.as_mut()[self.pos..]
72 }
73 }
74
75 // Implement `read()` as a concrete function to avoid extra monomorphization
76 // overhead.
read_impl(inner: &[u8], pos: &mut usize, buf: &mut [u8]) -> Result<usize>77 fn read_impl(inner: &[u8], pos: &mut usize, buf: &mut [u8]) -> Result<usize> {
78 let remaining = inner.len() - *pos;
79 let read_len = min(remaining, buf.len());
80 buf[..read_len].copy_from_slice(&inner[*pos..(*pos + read_len)]);
81 *pos += read_len;
82 Ok(read_len)
83 }
84
85 impl<T: AsRef<[u8]>> Read for Cursor<T> {
read(&mut self, buf: &mut [u8]) -> Result<usize>86 fn read(&mut self, buf: &mut [u8]) -> Result<usize> {
87 read_impl(self.inner.as_ref(), &mut self.pos, buf)
88 }
89 }
90
91 // Implement `write()` as a concrete function to avoid extra monomorphization
92 // overhead.
write_impl(inner: &mut [u8], pos: &mut usize, buf: &[u8]) -> Result<usize>93 fn write_impl(inner: &mut [u8], pos: &mut usize, buf: &[u8]) -> Result<usize> {
94 let remaining = inner.len() - *pos;
95 let write_len = min(remaining, buf.len());
96 inner[*pos..(*pos + write_len)].copy_from_slice(&buf[0..write_len]);
97 *pos += write_len;
98 Ok(write_len)
99 }
100
101 impl<T: AsRef<[u8]> + AsMut<[u8]>> Write for Cursor<T> {
write(&mut self, buf: &[u8]) -> Result<usize>102 fn write(&mut self, buf: &[u8]) -> Result<usize> {
103 write_impl(self.inner.as_mut(), &mut self.pos, buf)
104 }
105
flush(&mut self) -> Result<()>106 fn flush(&mut self) -> Result<()> {
107 // Cursor does not provide any buffering so flush() is a noop.
108 Ok(())
109 }
110 }
111
112 impl<T: AsRef<[u8]>> Seek for Cursor<T> {
seek(&mut self, pos: SeekFrom) -> Result<u64>113 fn seek(&mut self, pos: SeekFrom) -> Result<u64> {
114 let new_pos = match pos {
115 SeekFrom::Start(pos) => pos,
116 SeekFrom::Current(pos) => (self.pos as u64)
117 .checked_add_signed(pos)
118 .ok_or(Error::OutOfRange)?,
119 SeekFrom::End(pos) => (self.len() as u64)
120 .checked_add_signed(-pos)
121 .ok_or(Error::OutOfRange)?,
122 };
123
124 // Since Cursor operates on in memory buffers, it's limited by usize.
125 // Return an error if we are asked to seek beyond that limit.
126 let new_pos: usize = new_pos.try_into().map_err(|_| Error::OutOfRange)?;
127
128 if new_pos > self.len() {
129 Err(Error::OutOfRange)
130 } else {
131 self.pos = new_pos;
132 Ok(new_pos as u64)
133 }
134 }
135
136 // Implement more efficient versions of rewind, stream_len, stream_position.
rewind(&mut self) -> Result<()>137 fn rewind(&mut self) -> Result<()> {
138 self.pos = 0;
139 Ok(())
140 }
141
stream_len(&mut self) -> Result<u64>142 fn stream_len(&mut self) -> Result<u64> {
143 Ok(self.len() as u64)
144 }
145
stream_position(&mut self) -> Result<u64>146 fn stream_position(&mut self) -> Result<u64> {
147 Ok(self.pos as u64)
148 }
149 }
150
/// Generates `read_<ty>_<endian>()`, which decodes one `$ty` from the next
/// `$ty::BITS / 8` bytes using `from_<endian>_bytes` and advances the cursor.
/// Returns `Error::OutOfRange` without moving the cursor if too few bytes
/// remain.
macro_rules! cursor_read_type_impl {
    ($ty:ident, $endian:ident) => {
        paste! {
            fn [<read_ $ty _ $endian>](&mut self) -> Result<$ty> {
                const NUM_BYTES: usize = $ty::BITS as usize / 8;
                // One checked `get()` both bounds-checks and extracts the
                // bytes; a `None` means fewer than `NUM_BYTES` remain.
                let sub_slice = self
                    .inner
                    .as_ref()
                    .get(self.pos..self.pos + NUM_BYTES)
                    .ok_or(Error::OutOfRange)?;
                // Because we are code size conscious we want an infallible way
                // to turn `sub_slice` into a fixed sized array as opposed to
                // using something like `.try_into()?`.
                //
                // SAFETY: `get()` succeeded, so `sub_slice` is exactly
                // `NUM_BYTES` long and its pointer is valid for reads of a
                // `[u8; NUM_BYTES]`, which has alignment 1. The reference is
                // only used within this call, while `sub_slice` is borrowed.
                let sub_array: &[u8; NUM_BYTES] =
                    unsafe { &*sub_slice.as_ptr().cast::<[u8; NUM_BYTES]>() };
                let value = $ty::[<from_ $endian _bytes>](*sub_array);

                self.pos += NUM_BYTES;
                Ok(value)
            }
        }
    };
}
179
/// Expands to the four fixed-width readers for a bit width: unsigned and
/// signed, in little- and big-endian byte order.
macro_rules! cursor_read_bits_impl {
    ($bits:literal) => {
        paste! {
            // Little-endian readers.
            cursor_read_type_impl!([<u $bits>], le);
            cursor_read_type_impl!([<i $bits>], le);
            // Big-endian readers.
            cursor_read_type_impl!([<u $bits>], be);
            cursor_read_type_impl!([<i $bits>], be);
        }
    };
}
190
/// Generates `write_<ty>_<endian>()`, which encodes `*value` with
/// `to_<endian>_bytes` into the next `$ty::BITS / 8` bytes and advances the
/// cursor. Returns `Error::OutOfRange` without modifying anything if too few
/// bytes remain.
macro_rules! cursor_write_type_impl {
    ($ty:ident, $endian:ident) => {
        paste! {
            fn [<write_ $ty _ $endian>](&mut self, value: &$ty) -> Result<()> {
                const NUM_BYTES: usize = $ty::BITS as usize / 8;
                let value_bytes = $ty::[<to_ $endian _bytes>](*value);
                // One checked `get_mut()` both bounds-checks and selects the
                // destination; a `None` means fewer than `NUM_BYTES` remain.
                let sub_slice = self
                    .inner
                    .as_mut()
                    .get_mut(self.pos..self.pos + NUM_BYTES)
                    .ok_or(Error::OutOfRange)?;

                sub_slice.copy_from_slice(&value_bytes);

                self.pos += NUM_BYTES;
                Ok(())
            }
        }
    };
}
214
/// Expands to the four fixed-width writers for a bit width: unsigned and
/// signed, in little- and big-endian byte order.
macro_rules! cursor_write_bits_impl {
    ($bits:literal) => {
        paste! {
            // Little-endian writers.
            cursor_write_type_impl!([<u $bits>], le);
            cursor_write_type_impl!([<i $bits>], le);
            // Big-endian writers.
            cursor_write_type_impl!([<u $bits>], be);
            cursor_write_type_impl!([<i $bits>], be);
        }
    };
}
225
226 impl<T: AsRef<[u8]>> crate::ReadInteger for Cursor<T> {
227 cursor_read_bits_impl!(8);
228 cursor_read_bits_impl!(16);
229 cursor_read_bits_impl!(32);
230 cursor_read_bits_impl!(64);
231 cursor_read_bits_impl!(128);
232 }
233
234 impl<T: AsRef<[u8]> + AsMut<[u8]>> crate::WriteInteger for Cursor<T> {
235 cursor_write_bits_impl!(8);
236 cursor_write_bits_impl!(16);
237 cursor_write_bits_impl!(32);
238 cursor_write_bits_impl!(64);
239 cursor_write_bits_impl!(128);
240 }
241
242 impl<T: AsRef<[u8]>> crate::ReadVarint for Cursor<T> {
read_varint(&mut self) -> Result<u64>243 fn read_varint(&mut self) -> Result<u64> {
244 let (len, value) = u64::varint_decode(self.remaining_slice())?;
245 self.pos += len;
246 Ok(value)
247 }
248
read_signed_varint(&mut self) -> Result<i64>249 fn read_signed_varint(&mut self) -> Result<i64> {
250 let (len, value) = i64::varint_decode(self.remaining_slice())?;
251 self.pos += len;
252 Ok(value)
253 }
254 }
255
256 impl<T: AsRef<[u8]> + AsMut<[u8]>> crate::WriteVarint for Cursor<T> {
write_varint(&mut self, value: u64) -> Result<()>257 fn write_varint(&mut self, value: u64) -> Result<()> {
258 let encoded_len = value.varint_encode(self.remaining_mut())?;
259 self.pos += encoded_len;
260 Ok(())
261 }
262
write_signed_varint(&mut self, value: i64) -> Result<()>263 fn write_signed_varint(&mut self, value: i64) -> Result<()> {
264 let encoded_len = value.varint_encode(self.remaining_mut())?;
265 self.pos += encoded_len;
266 Ok(())
267 }
268 }
269
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{test_utils::*, ReadInteger, ReadVarint, WriteInteger, WriteVarint};

    #[test]
    fn cursor_len_returns_total_bytes() {
        let cursor = Cursor {
            inner: &[0u8; 64],
            pos: 31,
        };
        assert_eq!(cursor.len(), 64);
    }

    #[test]
    fn cursor_remaining_returns_remaining_bytes() {
        let cursor = Cursor {
            inner: &[0u8; 64],
            pos: 31,
        };
        assert_eq!(cursor.remaining(), 33);
    }

    #[test]
    fn cursor_position_returns_current_position() {
        let cursor = Cursor {
            inner: &[0u8; 64],
            pos: 31,
        };
        assert_eq!(cursor.position(), 31);
    }

    #[test]
    fn cursor_read_of_partial_buffer_reads_correct_data() {
        let mut cursor = Cursor {
            inner: &[1, 2, 3, 4, 5, 6, 7, 8],
            pos: 4,
        };
        let mut buf = [0u8; 8];
        assert_eq!(cursor.read(&mut buf), Ok(4));
        assert_eq!(buf, [5, 6, 7, 8, 0, 0, 0, 0]);
    }

    #[test]
    fn cursor_write_of_partial_buffer_writes_correct_data() {
        let mut cursor = Cursor {
            inner: &mut [0, 0, 0, 0, 0, 0, 0, 0],
            pos: 4,
        };
        // `write()` only needs a shared borrow of the source bytes.
        let buf = [1, 2, 3, 4, 5, 6, 7, 8];
        assert_eq!(cursor.write(&buf), Ok(4));
        assert_eq!(cursor.inner, &[0, 0, 0, 0, 1, 2, 3, 4]);
    }

    #[test]
    fn cursor_rewind_resets_position_to_zero() {
        test_rewind_resets_position_to_zero::<64, _>(Cursor::new(&[0u8; 64]));
    }

    #[test]
    fn cursor_stream_pos_reports_correct_position() {
        test_stream_pos_reports_correct_position::<64, _>(Cursor::new(&[0u8; 64]));
    }

    #[test]
    fn cursor_stream_len_reports_correct_length() {
        test_stream_len_reports_correct_length::<64, _>(Cursor::new(&[0u8; 64]));
    }

    #[test]
    fn cursor_seek_within_bounds_updates_position() {
        let mut cursor = Cursor::new(&[0u8; 64]);
        assert_eq!(cursor.seek(SeekFrom::Start(16)), Ok(16));
        assert_eq!(cursor.seek(SeekFrom::Current(8)), Ok(24));
        assert_eq!(cursor.seek(SeekFrom::Current(-24)), Ok(0));
        assert_eq!(cursor.seek(SeekFrom::Start(64)), Ok(64));
        assert_eq!(cursor.seek(SeekFrom::End(0)), Ok(64));
        assert_eq!(cursor.position(), 64);
    }

    #[test]
    fn cursor_seek_out_of_bounds_returns_error_and_keeps_position() {
        let mut cursor = Cursor::new(&[0u8; 64]);
        cursor.seek(SeekFrom::Start(8)).unwrap();
        assert_eq!(cursor.seek(SeekFrom::Start(65)), Err(Error::OutOfRange));
        assert_eq!(cursor.seek(SeekFrom::Current(-9)), Err(Error::OutOfRange));
        assert_eq!(cursor.position(), 8);
    }

    #[test]
    fn cursor_read_integer_past_end_returns_error_and_keeps_position() {
        let mut cursor = Cursor::new([1u8, 2, 3]);
        assert_eq!(cursor.read_u32_le(), Err(Error::OutOfRange));
        assert_eq!(cursor.position(), 0);
        assert_eq!(cursor.read_u16_be(), Ok(0x0102));
        assert_eq!(cursor.read_u16_be(), Err(Error::OutOfRange));
        assert_eq!(cursor.position(), 2);
    }

    #[test]
    fn cursor_write_integer_past_end_returns_error_and_keeps_position() {
        let mut cursor = Cursor::new([0u8; 3]);
        assert_eq!(cursor.write_u32_le(&0x0403_0201), Err(Error::OutOfRange));
        assert_eq!(cursor.position(), 0);
        cursor.write_u16_le(&0x0201).unwrap();
        assert_eq!(cursor.write_u16_le(&0x0403), Err(Error::OutOfRange));
        assert_eq!(cursor.position(), 2);
        assert_eq!(cursor.into_inner(), [1, 2, 0]);
    }

    macro_rules! cursor_read_n_bit_integers_unpacks_data_correctly {
        ($bits:literal) => {
            paste! {
                #[test]
                fn [<cursor_read_ $bits _bit_integers_unpacks_data_correctly>]() {
                    let (bytes, values) = [<integer_ $bits _bit_test_cases>]();
                    let mut cursor = Cursor::new(&bytes);

                    assert_eq!(cursor.[<read_i $bits _le>](), Ok(values.0));
                    assert_eq!(cursor.[<read_u $bits _le>](), Ok(values.1));
                    assert_eq!(cursor.[<read_i $bits _be>](), Ok(values.2));
                    assert_eq!(cursor.[<read_u $bits _be>](), Ok(values.3));
                }
            }
        };
    }

    macro_rules! cursor_write_n_bit_integers_packs_data_correctly {
        ($bits:literal) => {
            paste! {
                #[test]
                fn [<cursor_write_ $bits _bit_integers_packs_data_correctly>]() {
                    let (expected_bytes, values) = [<integer_ $bits _bit_test_cases>]();
                    let mut cursor = Cursor::new(vec![0u8; expected_bytes.len()]);
                    cursor.[<write_i $bits _le>](&values.0).unwrap();
                    cursor.[<write_u $bits _le>](&values.1).unwrap();
                    cursor.[<write_i $bits _be>](&values.2).unwrap();
                    cursor.[<write_u $bits _be>](&values.3).unwrap();

                    let result_bytes: Vec<u8> = cursor.into_inner().into();

                    assert_eq!(result_bytes, expected_bytes);
                }
            }
        };
    }

    fn integer_8_bit_test_cases() -> (Vec<u8>, (i8, u8, i8, u8)) {
        (
            vec![
                0x0, // le i8
                0x1, // le u8
                0x2, // be i8
                0x3, // be u8
            ],
            (0, 1, 2, 3),
        )
    }

    cursor_read_n_bit_integers_unpacks_data_correctly!(8);
    cursor_write_n_bit_integers_packs_data_correctly!(8);

    fn integer_16_bit_test_cases() -> (Vec<u8>, (i16, u16, i16, u16)) {
        (
            vec![
                0x0, 0x80, // le i16
                0x1, 0x80, // le u16
                0x80, 0x2, // be i16
                0x80, 0x3, // be u16
            ],
            (
                i16::from_le_bytes([0x0, 0x80]),
                0x8001,
                i16::from_be_bytes([0x80, 0x2]),
                0x8003,
            ),
        )
    }

    cursor_read_n_bit_integers_unpacks_data_correctly!(16);
    cursor_write_n_bit_integers_packs_data_correctly!(16);

    fn integer_32_bit_test_cases() -> (Vec<u8>, (i32, u32, i32, u32)) {
        (
            vec![
                0x0, 0x1, 0x2, 0x80, // le i32
                0x3, 0x4, 0x5, 0x80, // le u32
                0x80, 0x6, 0x7, 0x8, // be i32
                0x80, 0x9, 0xa, 0xb, // be u32
            ],
            (
                i32::from_le_bytes([0x0, 0x1, 0x2, 0x80]),
                0x8005_0403,
                i32::from_be_bytes([0x80, 0x6, 0x7, 0x8]),
                0x8009_0a0b,
            ),
        )
    }

    cursor_read_n_bit_integers_unpacks_data_correctly!(32);
    cursor_write_n_bit_integers_packs_data_correctly!(32);

    fn integer_64_bit_test_cases() -> (Vec<u8>, (i64, u64, i64, u64)) {
        (
            vec![
                0x0, 0x1, 0x2, 0x3, 0x4, 0x5, 0x6, 0x80, // le i64
                0x7, 0x8, 0x9, 0xa, 0xb, 0xc, 0xd, 0x80, // le u64
                0x80, 0x10, 0x11, 0x12, 0x13, 0x14, 0x15, 0x16, // be i64
                0x80, 0x17, 0x18, 0x19, 0x1a, 0x1b, 0x1c, 0x1d, // be u64
            ],
            (
                i64::from_le_bytes([0x0, 0x1, 0x2, 0x3, 0x4, 0x5, 0x6, 0x80]),
                0x800d_0c0b_0a09_0807,
                i64::from_be_bytes([0x80, 0x10, 0x11, 0x12, 0x13, 0x14, 0x15, 0x16]),
                0x8017_1819_1a1b_1c1d,
            ),
        )
    }

    cursor_read_n_bit_integers_unpacks_data_correctly!(64);
    cursor_write_n_bit_integers_packs_data_correctly!(64);

    fn integer_128_bit_test_cases() -> (Vec<u8>, (i128, u128, i128, u128)) {
        #[rustfmt::skip]
        let val = (
            vec![
                // le i128
                0x00, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07,
                0x08, 0x09, 0x0a, 0x0b, 0x0c, 0x0d, 0x0e, 0x8f,
                // le u128
                0x10, 0x11, 0x12, 0x13, 0x14, 0x15, 0x16, 0x17,
                0x18, 0x19, 0x1a, 0x1b, 0x1c, 0x1d, 0x1e, 0x8f,
                // be i128
                0x80, 0x21, 0x22, 0x23, 0x24, 0x25, 0x26, 0x27,
                0x28, 0x29, 0x2a, 0x2b, 0x2c, 0x2d, 0x2e, 0x2f,
                // be u128
                0x80, 0x31, 0x32, 0x33, 0x34, 0x35, 0x36, 0x37,
                0x38, 0x39, 0x3a, 0x3b, 0x3c, 0x3d, 0x3e, 0x3f,
            ],
            (
                i128::from_le_bytes([
                    0x00, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07,
                    0x08, 0x09, 0x0a, 0x0b, 0x0c, 0x0d, 0x0e, 0x8f,
                ]),
                0x8f1e_1d1c_1b1a_1918_1716_1514_1312_1110,
                i128::from_be_bytes([
                    0x80, 0x21, 0x22, 0x23, 0x24, 0x25, 0x26, 0x27,
                    0x28, 0x29, 0x2a, 0x2b, 0x2c, 0x2d, 0x2e, 0x2f,
                ]),
                0x8031_3233_3435_3637_3839_3a3b_3c3d_3e3f,
            ),
        );
        val
    }

    cursor_read_n_bit_integers_unpacks_data_correctly!(128);
    cursor_write_n_bit_integers_packs_data_correctly!(128);

    #[test]
    fn read_varint_unpacks_data_correctly() {
        let mut cursor = Cursor::new(vec![0xfe, 0xff, 0xff, 0xff, 0x0f, 0x0, 0x0, 0x0]);
        let value = cursor.read_varint().unwrap();
        assert_eq!(value, 0xffff_fffe);

        let mut cursor = Cursor::new(vec![0xff, 0xff, 0xff, 0xff, 0x0f, 0x0, 0x0, 0x0]);
        let value = cursor.read_varint().unwrap();
        assert_eq!(value, 0xffff_ffff);
    }

    #[test]
    fn read_signed_varint_unpacks_data_correctly() {
        let mut cursor = Cursor::new(vec![0xfe, 0xff, 0xff, 0xff, 0x0f, 0x0, 0x0, 0x0]);
        let value = cursor.read_signed_varint().unwrap();
        assert_eq!(value, i32::MAX.into());

        let mut cursor = Cursor::new(vec![0xff, 0xff, 0xff, 0xff, 0x0f, 0x0, 0x0, 0x0]);
        let value = cursor.read_signed_varint().unwrap();
        assert_eq!(value, i32::MIN.into());
    }

    #[test]
    fn write_varint_packs_data_correctly() {
        let mut cursor = Cursor::new(vec![0u8; 8]);
        cursor.write_varint(0xffff_fffe).unwrap();
        let buf = cursor.into_inner();
        assert_eq!(buf, vec![0xfe, 0xff, 0xff, 0xff, 0x0f, 0x0, 0x0, 0x0]);

        let mut cursor = Cursor::new(vec![0u8; 8]);
        cursor.write_varint(0xffff_ffff).unwrap();
        let buf = cursor.into_inner();
        assert_eq!(buf, vec![0xff, 0xff, 0xff, 0xff, 0x0f, 0x0, 0x0, 0x0]);
    }

    #[test]
    fn write_signed_varint_packs_data_correctly() {
        let mut cursor = Cursor::new(vec![0u8; 8]);
        cursor.write_signed_varint(i32::MAX.into()).unwrap();
        let buf = cursor.into_inner();
        assert_eq!(buf, vec![0xfe, 0xff, 0xff, 0xff, 0x0f, 0x0, 0x0, 0x0]);

        let mut cursor = Cursor::new(vec![0u8; 8]);
        cursor.write_signed_varint(i32::MIN.into()).unwrap();
        let buf = cursor.into_inner();
        assert_eq!(buf, vec![0xff, 0xff, 0xff, 0xff, 0x0f, 0x0, 0x0, 0x0]);
    }
}
535