• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 // Copyright 2023 The Pigweed Authors
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License"); you may not
4 // use this file except in compliance with the License. You may obtain a copy of
5 // the License at
6 //
7 //     https://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
11 // WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
12 // License for the specific language governing permissions and limitations under
13 // the License.
14 
15 use core::cmp::min;
16 
17 use paste::paste;
18 use pw_status::{Error, Result};
19 use pw_varint::{VarintDecode, VarintEncode};
20 
21 use super::{Read, Seek, SeekFrom, Write};
22 
/// Wraps an <code>[AsRef]<[u8]></code> in a container implementing
/// [`Read`], [`Write`], and [`Seek`].
///
/// [`Write`] support requires the inner type also implement
/// <code>[AsMut]<[u8]></code>.
///
/// The cursor keeps a byte offset (`pos`) into the wrapped data. Reads and
/// writes start at that offset and advance it by the number of bytes
/// transferred.
// The `AsRef<[u8]>` requirement lives on the impl blocks rather than on the
// type itself (Rust API guideline C-STRUCT-BOUNDS); every impl already states
// the bounds it needs.
pub struct Cursor<T> {
    inner: T,
    pos: usize,
}
35 
36 impl<T: AsRef<[u8]>> Cursor<T> {
37     /// Create a new Cursor wrapping `inner` with an initial position of 0.
38     ///
39     /// Semantics match [`std::io::Cursor::new()`].
new(inner: T) -> Self40     pub fn new(inner: T) -> Self {
41         Self { inner, pos: 0 }
42     }
43 
44     /// Consumes the cursor and returns the inner wrapped data.
into_inner(self) -> T45     pub fn into_inner(self) -> T {
46         self.inner
47     }
48 
49     /// Returns the number of remaining bytes in the Cursor.
remaining(&self) -> usize50     pub fn remaining(&self) -> usize {
51         self.len() - self.pos
52     }
53 
54     /// Returns the total length of the Cursor.
len(&self) -> usize55     pub fn len(&self) -> usize {
56         self.inner.as_ref().len()
57     }
58 
59     /// Returns current IO position of the Cursor.
position(&self) -> usize60     pub fn position(&self) -> usize {
61         self.pos
62     }
63 
remaining_slice(&mut self) -> &[u8]64     fn remaining_slice(&mut self) -> &[u8] {
65         &self.inner.as_ref()[self.pos..]
66     }
67 }
68 
69 impl<T: AsRef<[u8]> + AsMut<[u8]>> Cursor<T> {
remaining_mut(&mut self) -> &mut [u8]70     fn remaining_mut(&mut self) -> &mut [u8] {
71         &mut self.inner.as_mut()[self.pos..]
72     }
73 }
74 
75 // Implement `read()` as a concrete function to avoid extra monomorphization
76 // overhead.
read_impl(inner: &[u8], pos: &mut usize, buf: &mut [u8]) -> Result<usize>77 fn read_impl(inner: &[u8], pos: &mut usize, buf: &mut [u8]) -> Result<usize> {
78     let remaining = inner.len() - *pos;
79     let read_len = min(remaining, buf.len());
80     buf[..read_len].copy_from_slice(&inner[*pos..(*pos + read_len)]);
81     *pos += read_len;
82     Ok(read_len)
83 }
84 
85 impl<T: AsRef<[u8]>> Read for Cursor<T> {
read(&mut self, buf: &mut [u8]) -> Result<usize>86     fn read(&mut self, buf: &mut [u8]) -> Result<usize> {
87         read_impl(self.inner.as_ref(), &mut self.pos, buf)
88     }
89 }
90 
91 // Implement `write()` as a concrete function to avoid extra monomorphization
92 // overhead.
write_impl(inner: &mut [u8], pos: &mut usize, buf: &[u8]) -> Result<usize>93 fn write_impl(inner: &mut [u8], pos: &mut usize, buf: &[u8]) -> Result<usize> {
94     let remaining = inner.len() - *pos;
95     let write_len = min(remaining, buf.len());
96     inner[*pos..(*pos + write_len)].copy_from_slice(&buf[0..write_len]);
97     *pos += write_len;
98     Ok(write_len)
99 }
100 
101 impl<T: AsRef<[u8]> + AsMut<[u8]>> Write for Cursor<T> {
write(&mut self, buf: &[u8]) -> Result<usize>102     fn write(&mut self, buf: &[u8]) -> Result<usize> {
103         write_impl(self.inner.as_mut(), &mut self.pos, buf)
104     }
105 
flush(&mut self) -> Result<()>106     fn flush(&mut self) -> Result<()> {
107         // Cursor does not provide any buffering so flush() is a noop.
108         Ok(())
109     }
110 }
111 
112 impl<T: AsRef<[u8]>> Seek for Cursor<T> {
seek(&mut self, pos: SeekFrom) -> Result<u64>113     fn seek(&mut self, pos: SeekFrom) -> Result<u64> {
114         let new_pos = match pos {
115             SeekFrom::Start(pos) => pos,
116             SeekFrom::Current(pos) => (self.pos as u64)
117                 .checked_add_signed(pos)
118                 .ok_or(Error::OutOfRange)?,
119             SeekFrom::End(pos) => (self.len() as u64)
120                 .checked_add_signed(-pos)
121                 .ok_or(Error::OutOfRange)?,
122         };
123 
124         // Since Cursor operates on in memory buffers, it's limited by usize.
125         // Return an error if we are asked to seek beyond that limit.
126         let new_pos: usize = new_pos.try_into().map_err(|_| Error::OutOfRange)?;
127 
128         if new_pos > self.len() {
129             Err(Error::OutOfRange)
130         } else {
131             self.pos = new_pos;
132             Ok(new_pos as u64)
133         }
134     }
135 
136     // Implement more efficient versions of rewind, stream_len, stream_position.
rewind(&mut self) -> Result<()>137     fn rewind(&mut self) -> Result<()> {
138         self.pos = 0;
139         Ok(())
140     }
141 
stream_len(&mut self) -> Result<u64>142     fn stream_len(&mut self) -> Result<u64> {
143         Ok(self.len() as u64)
144     }
145 
stream_position(&mut self) -> Result<u64>146     fn stream_position(&mut self) -> Result<u64> {
147         Ok(self.pos as u64)
148     }
149 }
150 
// Generates `read_<ty>_<endian>()`. The generated method decodes a `$ty`
// from the next `$ty::BITS / 8` bytes with `from_<endian>_bytes()` and
// advances the position past them.
//
// Returns `Error::OutOfRange` (without moving the position) when fewer bytes
// remain than the type needs.
macro_rules! cursor_read_type_impl {
    ($ty:ident, $endian:ident) => {
        paste! {
            fn [<read_ $ty _ $endian>](&mut self) -> Result<$ty> {
                const NUM_BYTES: usize = $ty::BITS as usize / 8;
                if NUM_BYTES > self.remaining() {
                    return Err(Error::OutOfRange);
                }
                let sub_slice = self
                    .inner
                    .as_ref()
                    .get(self.pos..self.pos + NUM_BYTES)
                    .ok_or(Error::InvalidArgument)?;
                // Because we are code size conscious we want an infallible way to
                // turn `sub_slice` into a fixed sized array as opposed to using
                // something like `.try_into()?`.
                //
                // SAFETY: `get(self.pos..self.pos + NUM_BYTES)` succeeded, so
                // `sub_slice` is exactly `NUM_BYTES` initialized bytes. `[u8; N]`
                // has the same alignment (1) as `u8`, so reading the slice's
                // start pointer as `*const [u8; NUM_BYTES]` is valid. The
                // reference is only used below, while `sub_slice` is still
                // borrowed.
                let sub_array: &[u8; NUM_BYTES] =
                    unsafe { &*sub_slice.as_ptr().cast::<[u8; NUM_BYTES]>() };
                let value = $ty::[<from_ $endian _bytes>](*sub_array);

                self.pos += NUM_BYTES;
                Ok(value)
            }
        }
    };
}
179 
// Expands to the little- and big-endian readers for both the signed
// (`i<bits>`) and unsigned (`u<bits>`) integer of the given bit width.
macro_rules! cursor_read_bits_impl {
    ($bits:literal) => {
        paste! {
            // Signed readers.
            cursor_read_type_impl!([<i $bits>], le);
            cursor_read_type_impl!([<i $bits>], be);
            // Unsigned readers.
            cursor_read_type_impl!([<u $bits>], le);
            cursor_read_type_impl!([<u $bits>], be);
        }
    };
}
190 
// Generates `write_<ty>_<endian>()`. The generated method encodes `value`
// with `to_<endian>_bytes()`, stores it at the current position, and advances
// the position by `$ty::BITS / 8` bytes.
//
// Returns `Error::OutOfRange` (without writing anything or moving the
// position) when the remaining space is too small for the type.
macro_rules! cursor_write_type_impl {
    ($ty:ident, $endian:ident) => {
        paste! {
            fn [<write_ $ty _ $endian>](&mut self, value: &$ty) -> Result<()> {
                const NUM_BYTES: usize = $ty::BITS as usize / 8;
                // Check before touching the buffer so a value is never
                // partially written.
                if NUM_BYTES > self.remaining() {
                    return Err(Error::OutOfRange);
                }
                let value_bytes = $ty::[<to_ $endian _bytes>](*value);
                let sub_slice = self
                    .inner
                    .as_mut()
                    .get_mut(self.pos..self.pos + NUM_BYTES)
                    .ok_or(Error::InvalidArgument)?;

                sub_slice.copy_from_slice(&value_bytes);

                self.pos += NUM_BYTES;
                Ok(())
            }
        }
    };
}
214 
// Expands to the little- and big-endian writers for both the signed
// (`i<bits>`) and unsigned (`u<bits>`) integer of the given bit width.
macro_rules! cursor_write_bits_impl {
    ($bits:literal) => {
        paste! {
            // Signed writers.
            cursor_write_type_impl!([<i $bits>], le);
            cursor_write_type_impl!([<i $bits>], be);
            // Unsigned writers.
            cursor_write_type_impl!([<u $bits>], le);
            cursor_write_type_impl!([<u $bits>], be);
        }
    };
}
225 
226 impl<T: AsRef<[u8]>> crate::ReadInteger for Cursor<T> {
227     cursor_read_bits_impl!(8);
228     cursor_read_bits_impl!(16);
229     cursor_read_bits_impl!(32);
230     cursor_read_bits_impl!(64);
231     cursor_read_bits_impl!(128);
232 }
233 
234 impl<T: AsRef<[u8]> + AsMut<[u8]>> crate::WriteInteger for Cursor<T> {
235     cursor_write_bits_impl!(8);
236     cursor_write_bits_impl!(16);
237     cursor_write_bits_impl!(32);
238     cursor_write_bits_impl!(64);
239     cursor_write_bits_impl!(128);
240 }
241 
242 impl<T: AsRef<[u8]>> crate::ReadVarint for Cursor<T> {
read_varint(&mut self) -> Result<u64>243     fn read_varint(&mut self) -> Result<u64> {
244         let (len, value) = u64::varint_decode(self.remaining_slice())?;
245         self.pos += len;
246         Ok(value)
247     }
248 
read_signed_varint(&mut self) -> Result<i64>249     fn read_signed_varint(&mut self) -> Result<i64> {
250         let (len, value) = i64::varint_decode(self.remaining_slice())?;
251         self.pos += len;
252         Ok(value)
253     }
254 }
255 
256 impl<T: AsRef<[u8]> + AsMut<[u8]>> crate::WriteVarint for Cursor<T> {
write_varint(&mut self, value: u64) -> Result<()>257     fn write_varint(&mut self, value: u64) -> Result<()> {
258         let encoded_len = value.varint_encode(self.remaining_mut())?;
259         self.pos += encoded_len;
260         Ok(())
261     }
262 
write_signed_varint(&mut self, value: i64) -> Result<()>263     fn write_signed_varint(&mut self, value: i64) -> Result<()> {
264         let encoded_len = value.varint_encode(self.remaining_mut())?;
265         self.pos += encoded_len;
266         Ok(())
267     }
268 }
269 
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{test_utils::*, ReadInteger, ReadVarint, WriteInteger, WriteVarint};

    #[test]
    fn cursor_len_returns_total_bytes() {
        let cursor = Cursor {
            inner: &[0u8; 64],
            pos: 31,
        };
        assert_eq!(cursor.len(), 64);
    }

    #[test]
    fn cursor_remaining_returns_remaining_bytes() {
        let cursor = Cursor {
            inner: &[0u8; 64],
            pos: 31,
        };
        assert_eq!(cursor.remaining(), 33);
    }

    #[test]
    fn cursor_position_returns_current_position() {
        let cursor = Cursor {
            inner: &[0u8; 64],
            pos: 31,
        };
        assert_eq!(cursor.position(), 31);
    }

    #[test]
    fn cursor_read_of_partial_buffer_reads_correct_data() {
        let mut cursor = Cursor {
            inner: &[1, 2, 3, 4, 5, 6, 7, 8],
            pos: 4,
        };
        let mut buf = [0u8; 8];
        assert_eq!(cursor.read(&mut buf), Ok(4));
        assert_eq!(buf, [5, 6, 7, 8, 0, 0, 0, 0]);
    }

    #[test]
    fn cursor_write_of_partial_buffer_writes_correct_data() {
        let mut cursor = Cursor {
            inner: &mut [0, 0, 0, 0, 0, 0, 0, 0],
            pos: 4,
        };
        // `write()` only needs a shared borrow of the source bytes.
        let buf = [1, 2, 3, 4, 5, 6, 7, 8];
        assert_eq!(cursor.write(&buf), Ok(4));
        assert_eq!(cursor.inner, &[0, 0, 0, 0, 1, 2, 3, 4]);
    }

    #[test]
    fn cursor_seek_rejects_positions_outside_buffer() {
        let mut cursor = Cursor::new(&[0u8; 8]);
        // Seeking exactly to the end is allowed; one past it is not, and a
        // failed seek leaves the position untouched.
        assert_eq!(cursor.seek(SeekFrom::Start(8)), Ok(8));
        assert_eq!(cursor.seek(SeekFrom::Start(9)), Err(Error::OutOfRange));
        assert_eq!(cursor.position(), 8);
        // Relative seeks may not move before the start.
        assert_eq!(cursor.seek(SeekFrom::Current(-1)), Ok(7));
        assert_eq!(cursor.seek(SeekFrom::Current(-8)), Err(Error::OutOfRange));
        assert_eq!(cursor.position(), 7);
        assert_eq!(cursor.seek(SeekFrom::End(0)), Ok(8));
    }

    #[test]
    fn cursor_rewind_resets_position_to_zero() {
        test_rewind_resets_position_to_zero::<64, _>(Cursor::new(&[0u8; 64]));
    }

    #[test]
    fn cursor_stream_pos_reports_correct_position() {
        test_stream_pos_reports_correct_position::<64, _>(Cursor::new(&[0u8; 64]));
    }

    #[test]
    fn cursor_stream_len_reports_correct_length() {
        test_stream_len_reports_correct_length::<64, _>(Cursor::new(&[0u8; 64]));
    }

    // Generates a test that reads the `integer_<bits>_bit_test_cases()` bytes
    // back as i-le, u-le, i-be, u-be (in that order) and checks the values.
    macro_rules! cursor_read_n_bit_integers_unpacks_data_correctly {
        ($bits:literal) => {
            paste! {
              #[test]
              fn [<cursor_read_ $bits _bit_integers_unpacks_data_correctly>]() {
                  let (bytes, values) = [<integer_ $bits _bit_test_cases>]();
                  let mut cursor = Cursor::new(&bytes);

                  assert_eq!(cursor.[<read_i $bits _le>](), Ok(values.0));
                  assert_eq!(cursor.[<read_u $bits _le>](), Ok(values.1));
                  assert_eq!(cursor.[<read_i $bits _be>](), Ok(values.2));
                  assert_eq!(cursor.[<read_u $bits _be>](), Ok(values.3));
              }
            }
        };
    }

    // Generates a test that writes the `integer_<bits>_bit_test_cases()`
    // values as i-le, u-le, i-be, u-be (in that order) and checks the bytes.
    macro_rules! cursor_write_n_bit_integers_packs_data_correctly {
        ($bits:literal) => {
            paste! {
              #[test]
              fn [<cursor_write_ $bits _bit_integers_packs_data_correctly>]() {
                  let (expected_bytes, values) = [<integer_ $bits _bit_test_cases>]();
                  let mut cursor = Cursor::new(vec![0u8; expected_bytes.len()]);
                  cursor.[<write_i $bits _le>](&values.0).unwrap();
                  cursor.[<write_u $bits _le>](&values.1).unwrap();
                  cursor.[<write_i $bits _be>](&values.2).unwrap();
                  cursor.[<write_u $bits _be>](&values.3).unwrap();

                  // The cursor wraps a `Vec<u8>` already; no conversion needed.
                  let result_bytes = cursor.into_inner();

                  assert_eq!(result_bytes, expected_bytes);
              }
            }
        };
    }

    fn integer_8_bit_test_cases() -> (Vec<u8>, (i8, u8, i8, u8)) {
        (
            vec![
                0x0, // le i8
                0x1, // le u8
                0x2, // be i8
                0x3, // be u8
            ],
            (0, 1, 2, 3),
        )
    }

    cursor_read_n_bit_integers_unpacks_data_correctly!(8);
    cursor_write_n_bit_integers_packs_data_correctly!(8);

    fn integer_16_bit_test_cases() -> (Vec<u8>, (i16, u16, i16, u16)) {
        (
            vec![
                0x0, 0x80, // le i16
                0x1, 0x80, // le u16
                0x80, 0x2, // be i16
                0x80, 0x3, // be u16
            ],
            (
                i16::from_le_bytes([0x0, 0x80]),
                0x8001,
                i16::from_be_bytes([0x80, 0x2]),
                0x8003,
            ),
        )
    }

    cursor_read_n_bit_integers_unpacks_data_correctly!(16);
    cursor_write_n_bit_integers_packs_data_correctly!(16);

    fn integer_32_bit_test_cases() -> (Vec<u8>, (i32, u32, i32, u32)) {
        (
            vec![
                0x0, 0x1, 0x2, 0x80, // le i32
                0x3, 0x4, 0x5, 0x80, // le u32
                0x80, 0x6, 0x7, 0x8, // be i32
                0x80, 0x9, 0xa, 0xb, // be u32
            ],
            (
                i32::from_le_bytes([0x0, 0x1, 0x2, 0x80]),
                0x8005_0403,
                i32::from_be_bytes([0x80, 0x6, 0x7, 0x8]),
                0x8009_0a0b,
            ),
        )
    }

    cursor_read_n_bit_integers_unpacks_data_correctly!(32);
    cursor_write_n_bit_integers_packs_data_correctly!(32);

    fn integer_64_bit_test_cases() -> (Vec<u8>, (i64, u64, i64, u64)) {
        (
            vec![
                0x0, 0x1, 0x2, 0x3, 0x4, 0x5, 0x6, 0x80, // le i64
                0x7, 0x8, 0x9, 0xa, 0xb, 0xc, 0xd, 0x80, // le u64
                0x80, 0x10, 0x11, 0x12, 0x13, 0x14, 0x15, 0x16, // be i64
                0x80, 0x17, 0x18, 0x19, 0x1a, 0x1b, 0x1c, 0x1d, // be u64
            ],
            (
                i64::from_le_bytes([0x0, 0x1, 0x2, 0x3, 0x4, 0x5, 0x6, 0x80]),
                0x800d_0c0b_0a09_0807,
                i64::from_be_bytes([0x80, 0x10, 0x11, 0x12, 0x13, 0x14, 0x15, 0x16]),
                0x8017_1819_1a1b_1c1d,
            ),
        )
    }

    cursor_read_n_bit_integers_unpacks_data_correctly!(64);
    cursor_write_n_bit_integers_packs_data_correctly!(64);

    // `rustfmt::skip` sits on the function so the byte table keeps its
    // two-rows-of-eight layout without needing a `let` binding to hang it on.
    #[rustfmt::skip]
    fn integer_128_bit_test_cases() -> (Vec<u8>, (i128, u128, i128, u128)) {
        (
            vec![
                // le i128
                0x00, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07,
                0x08, 0x09, 0x0a, 0x0b, 0x0c, 0x0d, 0x0e, 0x8f,
                // le u128
                0x10, 0x11, 0x12, 0x13, 0x14, 0x15, 0x16, 0x17,
                0x18, 0x19, 0x1a, 0x1b, 0x1c, 0x1d, 0x1e, 0x8f,
                // be i128
                0x80, 0x21, 0x22, 0x23, 0x24, 0x25, 0x26, 0x27,
                0x28, 0x29, 0x2a, 0x2b, 0x2c, 0x2d, 0x2e, 0x2f,
                // be u128
                0x80, 0x31, 0x32, 0x33, 0x34, 0x35, 0x36, 0x37,
                0x38, 0x39, 0x3a, 0x3b, 0x3c, 0x3d, 0x3e, 0x3f,
            ],
            (
                i128::from_le_bytes([
                    0x00, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07,
                    0x08, 0x09, 0x0a, 0x0b, 0x0c, 0x0d, 0x0e, 0x8f,
                ]),
                0x8f1e_1d1c_1b1a_1918_1716_1514_1312_1110,
                i128::from_be_bytes([
                    0x80, 0x21, 0x22, 0x23, 0x24, 0x25, 0x26, 0x27,
                    0x28, 0x29, 0x2a, 0x2b, 0x2c, 0x2d, 0x2e, 0x2f,
                ]),
                0x8031_3233_3435_3637_3839_3a3b_3c3d_3e3f,
            ),
        )
    }

    cursor_read_n_bit_integers_unpacks_data_correctly!(128);
    cursor_write_n_bit_integers_packs_data_correctly!(128);

    #[test]
    fn read_varint_unpacks_data_correctly() {
        let mut cursor = Cursor::new(vec![0xfe, 0xff, 0xff, 0xff, 0x0f, 0x0, 0x0, 0x0]);
        let value = cursor.read_varint().unwrap();
        assert_eq!(value, 0xffff_fffe);

        let mut cursor = Cursor::new(vec![0xff, 0xff, 0xff, 0xff, 0x0f, 0x0, 0x0, 0x0]);
        let value = cursor.read_varint().unwrap();
        assert_eq!(value, 0xffff_ffff);
    }

    #[test]
    fn read_signed_varint_unpacks_data_correctly() {
        let mut cursor = Cursor::new(vec![0xfe, 0xff, 0xff, 0xff, 0x0f, 0x0, 0x0, 0x0]);
        let value = cursor.read_signed_varint().unwrap();
        assert_eq!(value, i32::MAX.into());

        let mut cursor = Cursor::new(vec![0xff, 0xff, 0xff, 0xff, 0x0f, 0x0, 0x0, 0x0]);
        let value = cursor.read_signed_varint().unwrap();
        assert_eq!(value, i32::MIN.into());
    }

    #[test]
    fn write_varint_packs_data_correctly() {
        let mut cursor = Cursor::new(vec![0u8; 8]);
        cursor.write_varint(0xffff_fffe).unwrap();
        let buf = cursor.into_inner();
        assert_eq!(buf, vec![0xfe, 0xff, 0xff, 0xff, 0x0f, 0x0, 0x0, 0x0]);

        let mut cursor = Cursor::new(vec![0u8; 8]);
        cursor.write_varint(0xffff_ffff).unwrap();
        let buf = cursor.into_inner();
        assert_eq!(buf, vec![0xff, 0xff, 0xff, 0xff, 0x0f, 0x0, 0x0, 0x0]);
    }

    #[test]
    fn write_signed_varint_packs_data_correctly() {
        let mut cursor = Cursor::new(vec![0u8; 8]);
        cursor.write_signed_varint(i32::MAX.into()).unwrap();
        let buf = cursor.into_inner();
        assert_eq!(buf, vec![0xfe, 0xff, 0xff, 0xff, 0x0f, 0x0, 0x0, 0x0]);

        let mut cursor = Cursor::new(vec![0u8; 8]);
        cursor.write_signed_varint(i32::MIN.into()).unwrap();
        let buf = cursor.into_inner();
        assert_eq!(buf, vec![0xff, 0xff, 0xff, 0xff, 0x0f, 0x0, 0x0, 0x0]);
    }
}
535