1 use std::cmp; 2 use std::io::BufRead; 3 use std::io::BufReader; 4 use std::io::Read; 5 use std::mem; 6 use std::mem::MaybeUninit; 7 8 #[cfg(feature = "bytes")] 9 use bytes::buf::UninitSlice; 10 #[cfg(feature = "bytes")] 11 use bytes::BufMut; 12 #[cfg(feature = "bytes")] 13 use bytes::Bytes; 14 #[cfg(feature = "bytes")] 15 use bytes::BytesMut; 16 17 use crate::coded_input_stream::buf_read_or_reader::BufReadOrReader; 18 use crate::coded_input_stream::input_buf::InputBuf; 19 use crate::coded_input_stream::input_source::InputSource; 20 use crate::coded_input_stream::READ_RAW_BYTES_MAX_ALLOC; 21 use crate::error::ProtobufError; 22 use crate::error::WireError; 23 24 // If an input stream is constructed with a `Read`, we create a 25 // `BufReader` with an internal buffer of this size. 26 const INPUT_STREAM_BUFFER_SIZE: usize = 4096; 27 28 const NO_LIMIT: u64 = u64::MAX; 29 30 /// Dangerous implementation of `BufRead`. 31 /// 32 /// Unsafe wrapper around BufRead which assumes that `BufRead` buf is 33 /// not moved when `BufRead` is moved. 34 /// 35 /// This assumption is generally incorrect, however, in practice 36 /// `BufReadIter` is created either from `BufRead` reference (which 37 /// cannot be moved, because it is locked by `CodedInputStream`) or from 38 /// `BufReader` which does not move its buffer (we know that from 39 /// inspecting rust standard library). 40 /// 41 /// It is important for `CodedInputStream` performance that small reads 42 /// (e. g. 4 bytes reads) do not involve virtual calls or switches. 43 /// This is achievable with `BufReadIter`. 44 #[derive(Debug)] 45 pub(crate) struct BufReadIter<'a> { 46 input_source: InputSource<'a>, 47 buf: InputBuf<'a>, 48 pos_of_buf_start: u64, 49 limit: u64, 50 } 51 52 impl<'a> Drop for BufReadIter<'a> { drop(&mut self)53 fn drop(&mut self) { 54 match self.input_source { 55 InputSource::Read(ref mut buf_read) => buf_read.consume(self.buf.pos_within_buf()), 56 _ => {} 57 } 58 } 59 } 60 61 impl<'a> BufReadIter<'a> { from_read(read: &'a mut dyn Read) -> BufReadIter<'a>62 pub(crate) fn from_read(read: &'a mut dyn Read) -> BufReadIter<'a> { 63 BufReadIter { 64 input_source: InputSource::Read(BufReadOrReader::BufReader(BufReader::with_capacity( 65 INPUT_STREAM_BUFFER_SIZE, 66 read, 67 ))), 68 buf: InputBuf::empty(), 69 pos_of_buf_start: 0, 70 limit: NO_LIMIT, 71 } 72 } 73 from_buf_read(buf_read: &'a mut dyn BufRead) -> BufReadIter<'a>74 pub(crate) fn from_buf_read(buf_read: &'a mut dyn BufRead) -> BufReadIter<'a> { 75 BufReadIter { 76 input_source: InputSource::Read(BufReadOrReader::BufRead(buf_read)), 77 buf: InputBuf::empty(), 78 pos_of_buf_start: 0, 79 limit: NO_LIMIT, 80 } 81 } 82 from_byte_slice(bytes: &'a [u8]) -> BufReadIter<'a>83 pub(crate) fn from_byte_slice(bytes: &'a [u8]) -> BufReadIter<'a> { 84 BufReadIter { 85 input_source: InputSource::Slice(bytes), 86 buf: InputBuf::from_bytes(bytes), 87 pos_of_buf_start: 0, 88 limit: NO_LIMIT, 89 } 90 } 91 92 #[cfg(feature = "bytes")] from_bytes(bytes: &'a Bytes) -> BufReadIter<'a>93 pub(crate) fn from_bytes(bytes: &'a Bytes) -> BufReadIter<'a> { 94 BufReadIter { 95 input_source: InputSource::Bytes(bytes), 96 buf: InputBuf::from_bytes(&bytes), 97 pos_of_buf_start: 0, 98 limit: NO_LIMIT, 99 } 100 } 101 102 #[inline] assertions(&self)103 fn assertions(&self) { 104 debug_assert!(self.pos() <= self.limit); 105 self.buf.assertions(); 106 } 107 108 #[inline(always)] pos(&self) -> u64109 pub(crate) fn pos(&self) -> u64 { 110 self.pos_of_buf_start + self.buf.pos_within_buf() as u64 111 } 112 113 /// Recompute `limit_within_buf` after update of `limit` 114 #[inline] update_limit_within_buf(&mut self)115 fn update_limit_within_buf(&mut self) { 116 assert!(self.limit >= self.pos_of_buf_start); 117 self.buf.update_limit(self.limit - self.pos_of_buf_start); 118 self.assertions(); 119 } 120 push_limit(&mut self, limit: u64) -> crate::Result<u64>121 pub(crate) fn push_limit(&mut self, limit: u64) -> crate::Result<u64> { 122 let new_limit = match self.pos().checked_add(limit) { 123 Some(new_limit) => new_limit, 124 None => return Err(ProtobufError::WireError(WireError::LimitOverflow).into()), 125 }; 126 127 if new_limit > self.limit { 128 return Err(ProtobufError::WireError(WireError::LimitIncrease).into()); 129 } 130 131 let prev_limit = mem::replace(&mut self.limit, new_limit); 132 133 self.update_limit_within_buf(); 134 135 Ok(prev_limit) 136 } 137 138 #[inline] pop_limit(&mut self, limit: u64)139 pub(crate) fn pop_limit(&mut self, limit: u64) { 140 assert!(limit >= self.limit); 141 142 self.limit = limit; 143 144 self.update_limit_within_buf(); 145 } 146 147 #[inline(always)] remaining_in_buf(&self) -> &[u8]148 pub(crate) fn remaining_in_buf(&self) -> &[u8] { 149 self.buf.remaining_in_buf() 150 } 151 152 #[inline] consume(&mut self, amt: usize)153 pub(crate) fn consume(&mut self, amt: usize) { 154 self.buf.consume(amt); 155 } 156 157 #[inline(always)] remaining_in_buf_len(&self) -> usize158 pub(crate) fn remaining_in_buf_len(&self) -> usize { 159 self.remaining_in_buf().len() 160 } 161 162 #[inline(always)] bytes_until_limit(&self) -> u64163 pub(crate) fn bytes_until_limit(&self) -> u64 { 164 if self.limit == NO_LIMIT { 165 NO_LIMIT 166 } else { 167 self.limit - self.pos() 168 } 169 } 170 171 #[inline(always)] eof(&mut self) -> crate::Result<bool>172 pub(crate) fn eof(&mut self) -> crate::Result<bool> { 173 if self.remaining_in_buf_len() != 0 { 174 Ok(false) 175 } else { 176 Ok(self.fill_buf()?.is_empty()) 177 } 178 } 179 read_byte_slow(&mut self) -> crate::Result<u8>180 fn read_byte_slow(&mut self) -> crate::Result<u8> { 181 self.fill_buf_slow()?; 182 183 if let Some(b) = self.buf.read_byte() { 184 return Ok(b); 185 } 186 187 Err(WireError::UnexpectedEof.into()) 188 } 189 190 #[inline(always)] read_byte(&mut self) -> crate::Result<u8>191 pub(crate) fn read_byte(&mut self) -> crate::Result<u8> { 192 if let Some(b) = self.buf.read_byte() { 193 return Ok(b); 194 } 195 196 self.read_byte_slow() 197 } 198 199 #[cfg(feature = "bytes")] read_exact_bytes(&mut self, len: usize) -> crate::Result<Bytes>200 pub(crate) fn read_exact_bytes(&mut self, len: usize) -> crate::Result<Bytes> { 201 if let InputSource::Bytes(bytes) = self.input_source { 202 if len > self.remaining_in_buf_len() { 203 return Err(ProtobufError::WireError(WireError::UnexpectedEof).into()); 204 } 205 let end = self.buf.pos_within_buf() + len; 206 207 let r = bytes.slice(self.buf.pos_within_buf()..end); 208 self.buf.consume(len); 209 Ok(r) 210 } else { 211 if len >= READ_RAW_BYTES_MAX_ALLOC { 212 // We cannot trust `len` because protobuf message could be malformed. 213 // Reading should not result in OOM when allocating a buffer. 214 let mut v = Vec::new(); 215 self.read_exact_to_vec(len, &mut v)?; 216 Ok(Bytes::from(v)) 217 } else { 218 let mut r = BytesMut::with_capacity(len); 219 unsafe { 220 let buf = Self::uninit_slice_as_mut_slice(&mut r.chunk_mut()[..len]); 221 self.read_exact(buf)?; 222 r.advance_mut(len); 223 } 224 Ok(r.freeze()) 225 } 226 } 227 } 228 229 #[cfg(feature = "bytes")] uninit_slice_as_mut_slice(slice: &mut UninitSlice) -> &mut [MaybeUninit<u8>]230 unsafe fn uninit_slice_as_mut_slice(slice: &mut UninitSlice) -> &mut [MaybeUninit<u8>] { 231 use std::slice; 232 slice::from_raw_parts_mut(slice.as_mut_ptr() as *mut MaybeUninit<u8>, slice.len()) 233 } 234 235 /// Returns 0 when EOF or limit reached. read(&mut self, buf: &mut [u8]) -> crate::Result<usize>236 pub(crate) fn read(&mut self, buf: &mut [u8]) -> crate::Result<usize> { 237 let rem = self.fill_buf()?; 238 239 let len = cmp::min(rem.len(), buf.len()); 240 buf[..len].copy_from_slice(&rem[..len]); 241 self.buf.consume(len); 242 Ok(len) 243 } 244 consume_buf(&mut self) -> crate::Result<()>245 fn consume_buf(&mut self) -> crate::Result<()> { 246 match &mut self.input_source { 247 InputSource::Read(read) => { 248 read.consume(self.buf.pos_within_buf()); 249 self.pos_of_buf_start += self.buf.pos_within_buf() as u64; 250 self.buf = InputBuf::empty(); 251 self.assertions(); 252 Ok(()) 253 } 254 _ => Err(WireError::UnexpectedEof.into()), 255 } 256 } 257 258 /// Read at most `max` bytes. 259 /// 260 /// Returns 0 when EOF or limit reached. read_to_vec(&mut self, vec: &mut Vec<u8>, max: usize) -> crate::Result<usize>261 fn read_to_vec(&mut self, vec: &mut Vec<u8>, max: usize) -> crate::Result<usize> { 262 let rem = self.fill_buf()?; 263 264 let len = cmp::min(rem.len(), max); 265 vec.extend_from_slice(&rem[..len]); 266 self.buf.consume(len); 267 Ok(len) 268 } 269 read_exact_slow(&mut self, buf: &mut [MaybeUninit<u8>]) -> crate::Result<()>270 fn read_exact_slow(&mut self, buf: &mut [MaybeUninit<u8>]) -> crate::Result<()> { 271 if self.bytes_until_limit() < buf.len() as u64 { 272 return Err(ProtobufError::WireError(WireError::UnexpectedEof).into()); 273 } 274 275 self.consume_buf()?; 276 277 match &mut self.input_source { 278 InputSource::Read(buf_read) => { 279 buf_read.read_exact_uninit(buf)?; 280 self.pos_of_buf_start += buf.len() as u64; 281 self.assertions(); 282 Ok(()) 283 } 284 _ => unreachable!(), 285 } 286 } 287 288 #[inline] read_exact(&mut self, buf: &mut [MaybeUninit<u8>]) -> crate::Result<()>289 pub(crate) fn read_exact(&mut self, buf: &mut [MaybeUninit<u8>]) -> crate::Result<()> { 290 if self.remaining_in_buf_len() >= buf.len() { 291 self.buf.read_bytes(buf); 292 return Ok(()); 293 } 294 295 self.read_exact_slow(buf) 296 } 297 298 /// Read raw bytes into the supplied vector. The vector will be resized as needed and 299 /// overwritten. read_exact_to_vec( &mut self, count: usize, target: &mut Vec<u8>, ) -> crate::Result<()>300 pub(crate) fn read_exact_to_vec( 301 &mut self, 302 count: usize, 303 target: &mut Vec<u8>, 304 ) -> crate::Result<()> { 305 // TODO: also do some limits when reading from unlimited source 306 if count as u64 > self.bytes_until_limit() { 307 return Err(ProtobufError::WireError(WireError::TruncatedMessage).into()); 308 } 309 310 target.clear(); 311 312 if count >= READ_RAW_BYTES_MAX_ALLOC && count > target.capacity() { 313 // avoid calling `reserve` on buf with very large buffer: could be a malformed message 314 315 target.reserve(READ_RAW_BYTES_MAX_ALLOC); 316 317 while target.len() < count { 318 if count - target.len() <= target.len() { 319 target.reserve_exact(count - target.len()); 320 } else { 321 target.reserve(1); 322 } 323 324 let max = cmp::min(target.capacity() - target.len(), count - target.len()); 325 let read = self.read_to_vec(target, max)?; 326 if read == 0 { 327 return Err(ProtobufError::WireError(WireError::TruncatedMessage).into()); 328 } 329 } 330 } else { 331 target.reserve_exact(count); 332 333 unsafe { 334 self.read_exact(&mut target.spare_capacity_mut()[..count])?; 335 target.set_len(count); 336 } 337 } 338 339 debug_assert_eq!(count, target.len()); 340 341 Ok(()) 342 } 343 skip_bytes(&mut self, count: u32) -> crate::Result<()>344 pub(crate) fn skip_bytes(&mut self, count: u32) -> crate::Result<()> { 345 if count as usize <= self.remaining_in_buf_len() { 346 self.buf.consume(count as usize); 347 return Ok(()); 348 } 349 350 if count as u64 > self.bytes_until_limit() { 351 return Err(WireError::TruncatedMessage.into()); 352 } 353 354 self.consume_buf()?; 355 356 match &mut self.input_source { 357 InputSource::Read(read) => { 358 read.skip_bytes(count as usize)?; 359 self.pos_of_buf_start += count as u64; 360 self.assertions(); 361 Ok(()) 362 } 363 _ => unreachable!(), 364 } 365 } 366 fill_buf_slow(&mut self) -> crate::Result<()>367 fn fill_buf_slow(&mut self) -> crate::Result<()> { 368 self.assertions(); 369 if self.limit == self.pos() { 370 return Ok(()); 371 } 372 373 match self.input_source { 374 InputSource::Read(..) => {} 375 _ => return Ok(()), 376 } 377 378 self.consume_buf()?; 379 380 match self.input_source { 381 InputSource::Read(ref mut buf_read) => { 382 self.buf = unsafe { InputBuf::from_bytes_ignore_lifetime(buf_read.fill_buf()?) }; 383 self.update_limit_within_buf(); 384 Ok(()) 385 } 386 _ => { 387 unreachable!(); 388 } 389 } 390 } 391 392 #[inline(always)] fill_buf(&mut self) -> crate::Result<&[u8]>393 pub(crate) fn fill_buf(&mut self) -> crate::Result<&[u8]> { 394 let rem = self.buf.remaining_in_buf(); 395 if !rem.is_empty() { 396 return Ok(rem); 397 } 398 399 if self.limit == self.pos() { 400 return Ok(&[]); 401 } 402 403 self.fill_buf_slow()?; 404 405 Ok(self.buf.remaining_in_buf()) 406 } 407 } 408 409 #[cfg(all(test, feature = "bytes"))] 410 mod test_bytes { 411 use std::io::Write; 412 413 use super::*; 414 make_long_string(len: usize) -> Vec<u8>415 fn make_long_string(len: usize) -> Vec<u8> { 416 let mut s = Vec::new(); 417 while s.len() < len { 418 let len = s.len(); 419 write!(&mut s, "{}", len).expect("unexpected"); 420 } 421 s.truncate(len); 422 s 423 } 424 425 #[test] 426 #[cfg_attr(miri, ignore)] // bytes violates SB, see https://github.com/tokio-rs/bytes/issues/522 read_exact_bytes_from_slice()427 fn read_exact_bytes_from_slice() { 428 let bytes = make_long_string(100); 429 let mut bri = BufReadIter::from_byte_slice(&bytes[..]); 430 assert_eq!(&bytes[..90], &bri.read_exact_bytes(90).unwrap()[..]); 431 assert_eq!(bytes[90], bri.read_byte().expect("read_byte")); 432 } 433 434 #[test] 435 #[cfg_attr(miri, ignore)] // bytes violates SB, see https://github.com/tokio-rs/bytes/issues/522 read_exact_bytes_from_bytes()436 fn read_exact_bytes_from_bytes() { 437 let bytes = Bytes::from(make_long_string(100)); 438 let mut bri = BufReadIter::from_bytes(&bytes); 439 let read = bri.read_exact_bytes(90).unwrap(); 440 assert_eq!(&bytes[..90], &read[..]); 441 assert_eq!(&bytes[..90].as_ptr(), &read.as_ptr()); 442 assert_eq!(bytes[90], bri.read_byte().expect("read_byte")); 443 } 444 } 445 446 #[cfg(test)] 447 mod test { 448 use std::io; 449 use std::io::BufRead; 450 use std::io::Read; 451 452 use super::*; 453 454 #[test] eof_at_limit()455 fn eof_at_limit() { 456 struct Read5ThenPanic { 457 pos: usize, 458 } 459 460 impl Read for Read5ThenPanic { 461 fn read(&mut self, _buf: &mut [u8]) -> io::Result<usize> { 462 unreachable!(); 463 } 464 } 465 466 impl BufRead for Read5ThenPanic { 467 fn fill_buf(&mut self) -> io::Result<&[u8]> { 468 assert_eq!(0, self.pos); 469 static ZERO_TO_FIVE: &'static [u8] = &[0, 1, 2, 3, 4]; 470 Ok(ZERO_TO_FIVE) 471 } 472 473 fn consume(&mut self, amt: usize) { 474 if amt == 0 { 475 // drop of BufReadIter 476 return; 477 } 478 479 assert_eq!(0, self.pos); 480 assert_eq!(5, amt); 481 self.pos += amt; 482 } 483 } 484 485 let mut read = Read5ThenPanic { pos: 0 }; 486 let mut buf_read_iter = BufReadIter::from_buf_read(&mut read); 487 assert_eq!(0, buf_read_iter.pos()); 488 let _prev_limit = buf_read_iter.push_limit(5); 489 buf_read_iter.read_byte().expect("read_byte"); 490 buf_read_iter 491 .read_exact(&mut [ 492 MaybeUninit::uninit(), 493 MaybeUninit::uninit(), 494 MaybeUninit::uninit(), 495 MaybeUninit::uninit(), 496 ]) 497 .expect("read_exact"); 498 assert!(buf_read_iter.eof().expect("eof")); 499 } 500 } 501