• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 use crate::{
2     err::{io::LimitedReadError, Layer, LenError},
3     *,
4 };
5 
/// Encapsulated reader with a maximum allowed read length.
///
/// This struct is used to limit data reads by lower protocol layers
/// (e.g. the payload_len in an IPv6Header limits how much data should
/// be read by the following layers).
///
/// A [`crate::err::LenError`] (wrapped in a
/// [`crate::err::io::LimitedReadError`]) is returned as soon as a read
/// would exceed the maximum read len. In that case the read is rejected
/// before any data is taken from the underlying reader.
#[cfg(feature = "std")]
#[cfg_attr(docsrs, doc(cfg(feature = "std")))]
pub struct LimitedReader<T> {
    /// Reader from which data will be read.
    reader: T,
    /// Maximum len that can still be read, measured from the start of the
    /// current layer (reduced by the bytes of previous layers whenever a new
    /// layer is started).
    max_len: usize,
    /// Source of the maximum length (used for len error).
    len_source: LenSource,
    /// Layer that is currently read (used for len error).
    layer: Layer,
    /// Offset of the layer that is currently read (used for len error).
    layer_offset: usize,
    /// Len that was read on the current layer (always `<= max_len`).
    read_len: usize,
}
30 
#[cfg(feature = "std")]
#[cfg_attr(docsrs, doc(cfg(feature = "std")))]
impl<T: std::io::Read> LimitedReader<T> {
    /// Setup a new limited reader.
    ///
    /// * `reader` - Reader from which the data is read.
    /// * `max_len` - Maximum number of bytes that may be read from `reader`
    ///   in total (over all layers started via [`LimitedReader::start_layer`]).
    /// * `len_source` - Source of `max_len` (reported in len errors).
    /// * `layer_offset` - Offset of the first layer (reported in len errors
    ///   as `layer_start_offset`).
    /// * `layer` - Layer that is read first (reported in len errors).
    pub fn new(
        reader: T,
        max_len: usize,
        len_source: LenSource,
        layer_offset: usize,
        layer: Layer,
    ) -> LimitedReader<T> {
        LimitedReader {
            reader,
            max_len,
            len_source,
            layer,
            layer_offset,
            read_len: 0,
        }
    }

    /// Maximum len that still can be read (on the current layer).
    pub fn max_len(&self) -> usize {
        self.max_len
    }

    /// Source of the maximum length (used for len error).
    pub fn len_source(&self) -> LenSource {
        self.len_source
    }

    /// Layer that is currently read (used for len error).
    pub fn layer(&self) -> Layer {
        self.layer
    }

    /// Offset of the layer that is currently read (used for len error).
    pub fn layer_offset(&self) -> usize {
        self.layer_offset
    }

    /// Len that was read on the current layer.
    pub fn read_len(&self) -> usize {
        self.read_len
    }

    /// Set current position as starting position for a layer.
    ///
    /// The bytes read on the previous layer are added to the layer offset
    /// and subtracted from the maximum len, and the read len is reset to 0.
    pub fn start_layer(&mut self, layer: Layer) {
        self.layer_offset += self.read_len;
        // Cannot underflow: `read_exact` only ever lets `read_len` grow up
        // to `max_len`.
        self.max_len -= self.read_len;
        self.read_len = 0;
        self.layer = layer;
    }

    /// Try to read exactly `buf.len()` bytes from the reader into `buf`.
    ///
    /// On success the read len of the current layer is increased by
    /// `buf.len()`.
    ///
    /// # Errors
    ///
    /// * [`LimitedReadError::Len`] if reading `buf.len()` more bytes would
    ///   exceed the maximum len of the current layer. Nothing is read from the
    ///   underlying reader in this case.
    /// * [`LimitedReadError::Io`] if the underlying reader fails. The read len
    ///   is not increased in this case (how many bytes the underlying reader
    ///   consumed is unspecified, see [`std::io::Read::read_exact`]).
    pub fn read_exact(&mut self, buf: &mut [u8]) -> Result<(), LimitedReadError> {
        use LimitedReadError::*;
        // `read_len <= max_len` is an invariant, so this cannot underflow.
        if self.max_len - self.read_len < buf.len() {
            Err(Len(LenError {
                // Saturating so a `max_len` close to `usize::MAX` cannot
                // overflow (panic in debug, wrap in release).
                required_len: self.read_len.saturating_add(buf.len()),
                len: self.max_len,
                len_source: self.len_source,
                layer: self.layer,
                layer_start_offset: self.layer_offset,
            }))
        } else {
            self.reader.read_exact(buf).map_err(Io)?;
            self.read_len += buf.len();
            Ok(())
        }
    }

    /// Consumes LimitedReader and returns the reader.
    pub fn take_reader(self) -> T {
        self.reader
    }
}
110 
#[cfg(feature = "std")]
#[cfg_attr(docsrs, doc(cfg(feature = "std")))]
impl<T: core::fmt::Debug> core::fmt::Debug for LimitedReader<T> {
    /// Writes every field of the limited reader, in declaration order.
    fn fmt(&self, formatter: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        // Exhaustive destructuring: adding a field without updating this
        // impl becomes a compile error.
        let LimitedReader {
            reader,
            max_len,
            len_source,
            layer,
            layer_offset,
            read_len,
        } = self;

        let mut out = formatter.debug_struct("LimitedReader");
        out.field("reader", reader);
        out.field("max_len", max_len);
        out.field("len_source", len_source);
        out.field("layer", layer);
        out.field("layer_offset", layer_offset);
        out.field("read_len", read_len);
        out.finish()
    }
}
125 
#[cfg(all(test, feature = "std"))]
mod tests {
    use std::format;
    use std::io::Cursor;

    use super::*;

    #[test]
    fn new() {
        let data = [1, 2, 3, 4];
        let actual = LimitedReader::new(
            Cursor::new(&data),
            data.len(),
            LenSource::Slice,
            5,
            Layer::Ipv4Header,
        );
        assert_eq!(actual.max_len, data.len());
        assert_eq!(actual.max_len(), data.len());
        assert_eq!(actual.len_source, LenSource::Slice);
        assert_eq!(actual.len_source(), LenSource::Slice);
        assert_eq!(actual.layer, Layer::Ipv4Header);
        assert_eq!(actual.layer(), Layer::Ipv4Header);
        assert_eq!(actual.layer_offset, 5);
        assert_eq!(actual.layer_offset(), 5);
        assert_eq!(actual.read_len, 0);
        assert_eq!(actual.read_len(), 0);
    }

    #[test]
    fn start_layer() {
        let data = [1, 2, 3, 4, 5];
        let mut r = LimitedReader::new(
            Cursor::new(&data),
            data.len(),
            LenSource::Slice,
            6,
            Layer::Ipv4Header,
        );
        {
            let mut read_result = [0u8; 2];
            r.read_exact(&mut read_result).unwrap();
            assert_eq!(read_result, [1, 2]);
        }
        r.start_layer(Layer::IpAuthHeader);

        assert_eq!(r.max_len, 3);
        assert_eq!(r.len_source, LenSource::Slice);
        assert_eq!(r.layer, Layer::IpAuthHeader);
        assert_eq!(r.layer_offset, 2 + 6);
        assert_eq!(r.read_len, 0);

        {
            let mut read_result = [0u8; 4];
            assert_eq!(
                r.read_exact(&mut read_result).unwrap_err().len().unwrap(),
                LenError {
                    required_len: 4,
                    len: 3,
                    len_source: LenSource::Slice,
                    layer: Layer::IpAuthHeader,
                    layer_start_offset: 2 + 6
                }
            );
        }
    }

    #[test]
    fn read_exact() {
        let data = [1, 2, 3, 4, 5];
        let mut r = LimitedReader::new(
            Cursor::new(&data),
            data.len() + 1,
            LenSource::Ipv4HeaderTotalLen,
            10,
            Layer::Ipv4Header,
        );

        // normal read
        {
            let mut read_result = [0u8; 2];
            r.read_exact(&mut read_result).unwrap();
            assert_eq!(read_result, [1, 2]);
        }

        // len error
        {
            let mut read_result = [0u8; 5];
            assert_eq!(
                r.read_exact(&mut read_result).unwrap_err().len().unwrap(),
                LenError {
                    required_len: 7,
                    len: 6,
                    len_source: LenSource::Ipv4HeaderTotalLen,
                    layer: Layer::Ipv4Header,
                    layer_start_offset: 10
                }
            );
        }

        // io error
        {
            let mut read_result = [0u8; 4];
            assert!(r.read_exact(&mut read_result).unwrap_err().io().is_some());
        }
    }

    /// Reading exactly the remaining allowed len succeeds, while a read
    /// exceeding it is rejected without consuming data from the reader.
    #[test]
    fn read_exact_boundary() {
        let data = [1, 2, 3, 4, 5];
        let mut r = LimitedReader::new(
            Cursor::new(&data),
            3,
            LenSource::Slice,
            0,
            Layer::Ipv4Header,
        );

        // over-limit read is rejected and does not change the read len
        {
            let mut read_result = [0u8; 4];
            assert_eq!(
                r.read_exact(&mut read_result).unwrap_err().len().unwrap(),
                LenError {
                    required_len: 4,
                    len: 3,
                    len_source: LenSource::Slice,
                    layer: Layer::Ipv4Header,
                    layer_start_offset: 0
                }
            );
            assert_eq!(r.read_len(), 0);
        }

        // reading exactly up to the limit succeeds
        {
            let mut read_result = [0u8; 3];
            r.read_exact(&mut read_result).unwrap();
            assert_eq!(read_result, [1, 2, 3]);
            assert_eq!(r.read_len(), 3);
        }

        // any further non-empty read is a len error
        {
            let mut read_result = [0u8; 1];
            assert_eq!(
                r.read_exact(&mut read_result).unwrap_err().len().unwrap(),
                LenError {
                    required_len: 4,
                    len: 3,
                    len_source: LenSource::Slice,
                    layer: Layer::Ipv4Header,
                    layer_start_offset: 0
                }
            );
        }

        // an empty read is always allowed
        {
            let mut read_result = [0u8; 0];
            r.read_exact(&mut read_result).unwrap();
            assert_eq!(r.read_len(), 3);
        }

        // only the successful read advanced the underlying reader
        assert_eq!(3, r.take_reader().position());
    }

    #[test]
    fn take_reader() {
        let data = [1, 2, 3, 4, 5];
        let mut r = LimitedReader::new(
            Cursor::new(&data),
            data.len(),
            LenSource::Slice,
            6,
            Layer::Ipv4Header,
        );
        {
            let mut read_result = [0u8; 2];
            r.read_exact(&mut read_result).unwrap();
            assert_eq!(read_result, [1, 2]);
        }
        let result = r.take_reader();
        assert_eq!(2, result.position());
    }

    #[test]
    fn debug() {
        let data = [1, 2, 3, 4];
        let actual = LimitedReader::new(
            Cursor::new(&data),
            data.len(),
            LenSource::Slice,
            5,
            Layer::Ipv4Header,
        );
        assert_eq!(
            format!("{:?}", actual),
            format!(
                "LimitedReader {{ reader: {:?}, max_len: {:?}, len_source: {:?}, layer: {:?}, layer_offset: {:?}, read_len: {:?} }}",
                &actual.reader,
                &actual.max_len,
                &actual.len_source,
                &actual.layer,
                &actual.layer_offset,
                &actual.read_len
            )
        );
    }
}
276