1 use crate::{ 2 err::{io::LimitedReadError, Layer, LenError}, 3 *, 4 }; 5 6 /// Encapsulated reader with an maximum allowed read length. 7 /// 8 /// This struct is used to limit data reads by lower protocol layers 9 /// (e.g. the payload_len in an IPv6Header limits how much data should 10 /// be read by the following layers). 11 /// 12 /// An [`crate::err::LenError`] is returned as soon as more than the 13 /// maximum read len is read. 14 #[cfg(feature = "std")] 15 #[cfg_attr(docsrs, doc(cfg(feature = "std")))] 16 pub struct LimitedReader<T> { 17 /// Reader from which data will be read. 18 reader: T, 19 /// Maximum len that still can be read (on the current layer). 20 max_len: usize, 21 /// Source of the maximum length. 22 len_source: LenSource, 23 /// Layer that is currently read (used for len error). 24 layer: Layer, 25 /// Offset of the layer that is currently read (used for len error). 26 layer_offset: usize, 27 /// Len that was read on the current layer. 28 read_len: usize, 29 } 30 31 #[cfg(feature = "std")] 32 #[cfg_attr(docsrs, doc(cfg(feature = "std")))] 33 impl<T: std::io::Read + Sized> LimitedReader<T> { 34 /// Setup a new limited reader. new( reader: T, max_len: usize, len_source: LenSource, layer_offset: usize, layer: Layer, ) -> LimitedReader<T>35 pub fn new( 36 reader: T, 37 max_len: usize, 38 len_source: LenSource, 39 layer_offset: usize, 40 layer: Layer, 41 ) -> LimitedReader<T> { 42 LimitedReader { 43 reader, 44 max_len, 45 len_source, 46 layer, 47 layer_offset, 48 read_len: 0, 49 } 50 } 51 52 /// Maximum len that still can be read (on the current layer). max_len(&self) -> usize53 pub fn max_len(&self) -> usize { 54 self.max_len 55 } 56 57 /// Source of the maximum length (used for len error). len_source(&self) -> LenSource58 pub fn len_source(&self) -> LenSource { 59 self.len_source 60 } 61 62 /// Layer that is currently read (used for len error). layer(&self) -> Layer63 pub fn layer(&self) -> Layer { 64 self.layer 65 } 66 67 /// Offset of the layer that is currently read (used for len error). layer_offset(&self) -> usize68 pub fn layer_offset(&self) -> usize { 69 self.layer_offset 70 } 71 72 /// Len that was read on the current layer. read_len(&self) -> usize73 pub fn read_len(&self) -> usize { 74 self.read_len 75 } 76 77 /// Set current position as starting position for a layer. start_layer(&mut self, layer: Layer)78 pub fn start_layer(&mut self, layer: Layer) { 79 self.layer_offset += self.read_len; 80 self.max_len -= self.read_len; 81 self.read_len = 0; 82 self.layer = layer; 83 } 84 85 /// Try read the given buf length from the reader. 86 /// 87 /// Triggers an len error if the read_exact(&mut self, buf: &mut [u8]) -> Result<(), LimitedReadError>88 pub fn read_exact(&mut self, buf: &mut [u8]) -> Result<(), LimitedReadError> { 89 use LimitedReadError::*; 90 if self.max_len - self.read_len < buf.len() { 91 Err(Len(LenError { 92 required_len: self.read_len + buf.len(), 93 len: self.max_len, 94 len_source: self.len_source, 95 layer: self.layer, 96 layer_start_offset: self.layer_offset, 97 })) 98 } else { 99 self.reader.read_exact(buf).map_err(Io)?; 100 self.read_len += buf.len(); 101 Ok(()) 102 } 103 } 104 105 /// Consumes LimitedReader and returns the reader. take_reader(self) -> T106 pub fn take_reader(self) -> T { 107 self.reader 108 } 109 } 110 111 #[cfg(feature = "std")] 112 #[cfg_attr(docsrs, doc(cfg(feature = "std")))] 113 impl<T: core::fmt::Debug> core::fmt::Debug for LimitedReader<T> { fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result114 fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result { 115 f.debug_struct("LimitedReader") 116 .field("reader", &self.reader) 117 .field("max_len", &self.max_len) 118 .field("len_source", &self.len_source) 119 .field("layer", &self.layer) 120 .field("layer_offset", &self.layer_offset) 121 .field("read_len", &self.read_len) 122 .finish() 123 } 124 } 125 126 #[cfg(all(test, feature = "std"))] 127 mod tests { 128 use std::format; 129 use std::io::Cursor; 130 131 use super::*; 132 133 #[test] new()134 fn new() { 135 let data = [1, 2, 3, 4]; 136 let actual = LimitedReader::new( 137 Cursor::new(&data), 138 data.len(), 139 LenSource::Slice, 140 5, 141 Layer::Ipv4Header, 142 ); 143 assert_eq!(actual.max_len, data.len()); 144 assert_eq!(actual.max_len(), data.len()); 145 assert_eq!(actual.len_source, LenSource::Slice); 146 assert_eq!(actual.len_source(), LenSource::Slice); 147 assert_eq!(actual.layer, Layer::Ipv4Header); 148 assert_eq!(actual.layer(), Layer::Ipv4Header); 149 assert_eq!(actual.layer_offset, 5); 150 assert_eq!(actual.layer_offset(), 5); 151 assert_eq!(actual.read_len, 0); 152 assert_eq!(actual.read_len(), 0); 153 } 154 155 #[test] start_layer()156 fn start_layer() { 157 let data = [1, 2, 3, 4, 5]; 158 let mut r = LimitedReader::new( 159 Cursor::new(&data), 160 data.len(), 161 LenSource::Slice, 162 6, 163 Layer::Ipv4Header, 164 ); 165 { 166 let mut read_result = [0u8; 2]; 167 r.read_exact(&mut read_result).unwrap(); 168 assert_eq!(read_result, [1, 2]); 169 } 170 r.start_layer(Layer::IpAuthHeader); 171 172 assert_eq!(r.max_len, 3); 173 assert_eq!(r.len_source, LenSource::Slice); 174 assert_eq!(r.layer, Layer::IpAuthHeader); 175 assert_eq!(r.layer_offset, 2 + 6); 176 assert_eq!(r.read_len, 0); 177 178 { 179 let mut read_result = [0u8; 4]; 180 assert_eq!( 181 r.read_exact(&mut read_result).unwrap_err().len().unwrap(), 182 LenError { 183 required_len: 4, 184 len: 3, 185 len_source: LenSource::Slice, 186 layer: Layer::IpAuthHeader, 187 layer_start_offset: 2 + 6 188 } 189 ); 190 } 191 } 192 193 #[test] read_exact()194 fn read_exact() { 195 let data = [1, 2, 3, 4, 5]; 196 let mut r = LimitedReader::new( 197 Cursor::new(&data), 198 data.len() + 1, 199 LenSource::Ipv4HeaderTotalLen, 200 10, 201 Layer::Ipv4Header, 202 ); 203 204 // normal read 205 { 206 let mut read_result = [0u8; 2]; 207 r.read_exact(&mut read_result).unwrap(); 208 assert_eq!(read_result, [1, 2]); 209 } 210 211 // len error 212 { 213 let mut read_result = [0u8; 5]; 214 assert_eq!( 215 r.read_exact(&mut read_result).unwrap_err().len().unwrap(), 216 LenError { 217 required_len: 7, 218 len: 6, 219 len_source: LenSource::Ipv4HeaderTotalLen, 220 layer: Layer::Ipv4Header, 221 layer_start_offset: 10 222 } 223 ); 224 } 225 226 // io error 227 { 228 let mut read_result = [0u8; 4]; 229 assert!(r.read_exact(&mut read_result).unwrap_err().io().is_some()); 230 } 231 } 232 233 #[test] take_reader()234 fn take_reader() { 235 let data = [1, 2, 3, 4, 5]; 236 let mut r = LimitedReader::new( 237 Cursor::new(&data), 238 data.len(), 239 LenSource::Slice, 240 6, 241 Layer::Ipv4Header, 242 ); 243 { 244 let mut read_result = [0u8; 2]; 245 r.read_exact(&mut read_result).unwrap(); 246 assert_eq!(read_result, [1, 2]); 247 } 248 let result = r.take_reader(); 249 assert_eq!(2, result.position()); 250 } 251 252 #[test] debug()253 fn debug() { 254 let data = [1, 2, 3, 4]; 255 let actual = LimitedReader::new( 256 Cursor::new(&data), 257 data.len(), 258 LenSource::Slice, 259 5, 260 Layer::Ipv4Header, 261 ); 262 assert_eq!( 263 format!("{:?}", actual), 264 format!( 265 "LimitedReader {{ reader: {:?}, max_len: {:?}, len_source: {:?}, layer: {:?}, layer_offset: {:?}, read_len: {:?} }}", 266 &actual.reader, 267 &actual.max_len, 268 &actual.len_source, 269 &actual.layer, 270 &actual.layer_offset, 271 &actual.read_len 272 ) 273 ); 274 } 275 } 276