1 use std::cmp; 2 use std::io::BufRead; 3 use std::io::BufReader; 4 use std::io::Read; 5 use std::mem; 6 use std::mem::MaybeUninit; 7 use std::u64; 8 9 #[cfg(feature = "bytes")] 10 use bytes::buf::UninitSlice; 11 #[cfg(feature = "bytes")] 12 use bytes::BufMut; 13 #[cfg(feature = "bytes")] 14 use bytes::Bytes; 15 #[cfg(feature = "bytes")] 16 use bytes::BytesMut; 17 18 use crate::buf_read_or_reader::BufReadOrReader; 19 use crate::coded_input_stream::READ_RAW_BYTES_MAX_ALLOC; 20 use crate::error::WireError; 21 use crate::misc::maybe_uninit_write_slice; 22 use crate::misc::vec_spare_capacity_mut; 23 use crate::ProtobufError; 24 use crate::ProtobufResult; 25 26 // If an input stream is constructed with a `Read`, we create a 27 // `BufReader` with an internal buffer of this size. 28 const INPUT_STREAM_BUFFER_SIZE: usize = 4096; 29 30 const USE_UNSAFE_FOR_SPEED: bool = true; 31 32 const NO_LIMIT: u64 = u64::MAX; 33 34 /// Hold all possible combinations of input source 35 enum InputSource<'a> { 36 Read(BufReadOrReader<'a>), 37 Slice(&'a [u8]), 38 #[cfg(feature = "bytes")] 39 Bytes(&'a Bytes), 40 } 41 42 /// Dangerous implementation of `BufRead`. 43 /// 44 /// Unsafe wrapper around BufRead which assumes that `BufRead` buf is 45 /// not moved when `BufRead` is moved. 46 /// 47 /// This assumption is generally incorrect, however, in practice 48 /// `BufReadIter` is created either from `BufRead` reference (which 49 /// cannot be moved, because it is locked by `CodedInputStream`) or from 50 /// `BufReader` which does not move its buffer (we know that from 51 /// inspecting rust standard library). 52 /// 53 /// It is important for `CodedInputStream` performance that small reads 54 /// (e. g. 4 bytes reads) do not involve virtual calls or switches. 55 /// This is achievable with `BufReadIter`. 56 pub(crate) struct BufReadIter<'a> { 57 input_source: InputSource<'a>, 58 buf: &'a [u8], 59 pos_within_buf: usize, 60 limit_within_buf: usize, 61 pos_of_buf_start: u64, 62 limit: u64, 63 } 64 65 impl<'a> Drop for BufReadIter<'a> { drop(&mut self)66 fn drop(&mut self) { 67 match self.input_source { 68 InputSource::Read(ref mut buf_read) => buf_read.consume(self.pos_within_buf), 69 _ => {} 70 } 71 } 72 } 73 74 impl<'ignore> BufReadIter<'ignore> { from_read<'a>(read: &'a mut dyn Read) -> BufReadIter<'a>75 pub(crate) fn from_read<'a>(read: &'a mut dyn Read) -> BufReadIter<'a> { 76 BufReadIter { 77 input_source: InputSource::Read(BufReadOrReader::BufReader(BufReader::with_capacity( 78 INPUT_STREAM_BUFFER_SIZE, 79 read, 80 ))), 81 buf: &[], 82 pos_within_buf: 0, 83 limit_within_buf: 0, 84 pos_of_buf_start: 0, 85 limit: NO_LIMIT, 86 } 87 } 88 from_buf_read<'a>(buf_read: &'a mut dyn BufRead) -> BufReadIter<'a>89 pub(crate) fn from_buf_read<'a>(buf_read: &'a mut dyn BufRead) -> BufReadIter<'a> { 90 BufReadIter { 91 input_source: InputSource::Read(BufReadOrReader::BufRead(buf_read)), 92 buf: &[], 93 pos_within_buf: 0, 94 limit_within_buf: 0, 95 pos_of_buf_start: 0, 96 limit: NO_LIMIT, 97 } 98 } 99 from_byte_slice<'a>(bytes: &'a [u8]) -> BufReadIter<'a>100 pub(crate) fn from_byte_slice<'a>(bytes: &'a [u8]) -> BufReadIter<'a> { 101 BufReadIter { 102 input_source: InputSource::Slice(bytes), 103 buf: bytes, 104 pos_within_buf: 0, 105 limit_within_buf: bytes.len(), 106 pos_of_buf_start: 0, 107 limit: NO_LIMIT, 108 } 109 } 110 111 #[cfg(feature = "bytes")] from_bytes<'a>(bytes: &'a Bytes) -> BufReadIter<'a>112 pub(crate) fn from_bytes<'a>(bytes: &'a Bytes) -> BufReadIter<'a> { 113 BufReadIter { 114 input_source: InputSource::Bytes(bytes), 115 buf: &bytes, 116 pos_within_buf: 0, 117 limit_within_buf: bytes.len(), 118 pos_of_buf_start: 0, 119 limit: NO_LIMIT, 120 } 121 } 122 123 #[inline] assertions(&self)124 fn assertions(&self) { 125 debug_assert!(self.pos_within_buf <= self.limit_within_buf); 126 debug_assert!(self.limit_within_buf <= self.buf.len()); 127 debug_assert!(self.pos_of_buf_start + self.pos_within_buf as u64 <= self.limit); 128 } 129 130 #[inline(always)] pos(&self) -> u64131 pub(crate) fn pos(&self) -> u64 { 132 self.pos_of_buf_start + self.pos_within_buf as u64 133 } 134 135 /// Recompute `limit_within_buf` after update of `limit` 136 #[inline] update_limit_within_buf(&mut self)137 fn update_limit_within_buf(&mut self) { 138 if self.pos_of_buf_start + (self.buf.len() as u64) <= self.limit { 139 self.limit_within_buf = self.buf.len(); 140 } else { 141 self.limit_within_buf = (self.limit - self.pos_of_buf_start) as usize; 142 } 143 144 self.assertions(); 145 } 146 push_limit(&mut self, limit: u64) -> ProtobufResult<u64>147 pub(crate) fn push_limit(&mut self, limit: u64) -> ProtobufResult<u64> { 148 let new_limit = match self.pos().checked_add(limit) { 149 Some(new_limit) => new_limit, 150 None => return Err(ProtobufError::WireError(WireError::Other)), 151 }; 152 153 if new_limit > self.limit { 154 return Err(ProtobufError::WireError(WireError::Other)); 155 } 156 157 let prev_limit = mem::replace(&mut self.limit, new_limit); 158 159 self.update_limit_within_buf(); 160 161 Ok(prev_limit) 162 } 163 164 #[inline] pop_limit(&mut self, limit: u64)165 pub(crate) fn pop_limit(&mut self, limit: u64) { 166 assert!(limit >= self.limit); 167 168 self.limit = limit; 169 170 self.update_limit_within_buf(); 171 } 172 173 #[inline] remaining_in_buf(&self) -> &[u8]174 pub(crate) fn remaining_in_buf(&self) -> &[u8] { 175 if USE_UNSAFE_FOR_SPEED { 176 unsafe { 177 &self 178 .buf 179 .get_unchecked(self.pos_within_buf..self.limit_within_buf) 180 } 181 } else { 182 &self.buf[self.pos_within_buf..self.limit_within_buf] 183 } 184 } 185 186 #[inline(always)] remaining_in_buf_len(&self) -> usize187 pub(crate) fn remaining_in_buf_len(&self) -> usize { 188 self.limit_within_buf - self.pos_within_buf 189 } 190 191 #[inline(always)] bytes_until_limit(&self) -> u64192 pub(crate) fn bytes_until_limit(&self) -> u64 { 193 if self.limit == NO_LIMIT { 194 NO_LIMIT 195 } else { 196 self.limit - (self.pos_of_buf_start + self.pos_within_buf as u64) 197 } 198 } 199 200 #[inline(always)] eof(&mut self) -> ProtobufResult<bool>201 pub(crate) fn eof(&mut self) -> ProtobufResult<bool> { 202 if self.pos_within_buf == self.limit_within_buf { 203 Ok(self.fill_buf()?.is_empty()) 204 } else { 205 Ok(false) 206 } 207 } 208 209 #[inline(always)] read_byte(&mut self) -> ProtobufResult<u8>210 pub(crate) fn read_byte(&mut self) -> ProtobufResult<u8> { 211 if self.pos_within_buf == self.limit_within_buf { 212 self.do_fill_buf()?; 213 if self.remaining_in_buf_len() == 0 { 214 return Err(ProtobufError::WireError(WireError::UnexpectedEof)); 215 } 216 } 217 218 let r = if USE_UNSAFE_FOR_SPEED { 219 unsafe { *self.buf.get_unchecked(self.pos_within_buf) } 220 } else { 221 self.buf[self.pos_within_buf] 222 }; 223 self.pos_within_buf += 1; 224 Ok(r) 225 } 226 227 /// Read at most `max` bytes, append to `Vec`. 228 /// 229 /// Returns 0 when EOF or limit reached. read_to_vec(&mut self, vec: &mut Vec<u8>, max: usize) -> ProtobufResult<usize>230 fn read_to_vec(&mut self, vec: &mut Vec<u8>, max: usize) -> ProtobufResult<usize> { 231 let len = { 232 let rem = self.fill_buf()?; 233 234 let len = cmp::min(rem.len(), max); 235 vec.extend_from_slice(&rem[..len]); 236 len 237 }; 238 self.pos_within_buf += len; 239 Ok(len) 240 } 241 242 #[cfg(feature = "bytes")] read_exact_bytes(&mut self, len: usize) -> ProtobufResult<Bytes>243 pub(crate) fn read_exact_bytes(&mut self, len: usize) -> ProtobufResult<Bytes> { 244 if let InputSource::Bytes(bytes) = self.input_source { 245 let end = match self.pos_within_buf.checked_add(len) { 246 Some(end) => end, 247 None => return Err(ProtobufError::WireError(WireError::UnexpectedEof)), 248 }; 249 250 if end > self.limit_within_buf { 251 return Err(ProtobufError::WireError(WireError::UnexpectedEof)); 252 } 253 254 let r = bytes.slice(self.pos_within_buf..end); 255 self.pos_within_buf += len; 256 Ok(r) 257 } else { 258 if len >= READ_RAW_BYTES_MAX_ALLOC { 259 // We cannot trust `len` because protobuf message could be malformed. 260 // Reading should not result in OOM when allocating a buffer. 261 let mut v = Vec::new(); 262 self.read_exact_to_vec(len, &mut v)?; 263 Ok(Bytes::from(v)) 264 } else { 265 let mut r = BytesMut::with_capacity(len); 266 unsafe { 267 let buf = Self::uninit_slice_as_mut_slice(&mut r.chunk_mut()[..len]); 268 self.read_exact(buf)?; 269 r.advance_mut(len); 270 } 271 Ok(r.freeze()) 272 } 273 } 274 } 275 276 #[cfg(feature = "bytes")] uninit_slice_as_mut_slice(slice: &mut UninitSlice) -> &mut [MaybeUninit<u8>]277 unsafe fn uninit_slice_as_mut_slice(slice: &mut UninitSlice) -> &mut [MaybeUninit<u8>] { 278 use std::slice; 279 slice::from_raw_parts_mut(slice.as_mut_ptr() as *mut MaybeUninit<u8>, slice.len()) 280 } 281 282 /// Returns 0 when EOF or limit reached. read(&mut self, buf: &mut [u8]) -> ProtobufResult<usize>283 pub(crate) fn read(&mut self, buf: &mut [u8]) -> ProtobufResult<usize> { 284 self.fill_buf()?; 285 286 let rem = &self.buf[self.pos_within_buf..self.limit_within_buf]; 287 288 let len = cmp::min(rem.len(), buf.len()); 289 buf[..len].copy_from_slice(&rem[..len]); 290 self.pos_within_buf += len; 291 Ok(len) 292 } 293 read_exact_slow(&mut self, buf: &mut [MaybeUninit<u8>]) -> ProtobufResult<()>294 fn read_exact_slow(&mut self, buf: &mut [MaybeUninit<u8>]) -> ProtobufResult<()> { 295 if self.bytes_until_limit() < buf.len() as u64 { 296 return Err(ProtobufError::WireError(WireError::UnexpectedEof)); 297 } 298 299 let consume = self.pos_within_buf; 300 self.pos_of_buf_start += self.pos_within_buf as u64; 301 self.pos_within_buf = 0; 302 self.buf = &[]; 303 self.limit_within_buf = 0; 304 305 match self.input_source { 306 InputSource::Read(ref mut buf_read) => { 307 buf_read.consume(consume); 308 buf_read.read_exact_uninit(buf)?; 309 } 310 _ => { 311 return Err(ProtobufError::WireError(WireError::UnexpectedEof)); 312 } 313 } 314 315 self.pos_of_buf_start += buf.len() as u64; 316 317 self.assertions(); 318 319 Ok(()) 320 } 321 322 #[inline] read_exact(&mut self, buf: &mut [MaybeUninit<u8>]) -> ProtobufResult<()>323 pub(crate) fn read_exact(&mut self, buf: &mut [MaybeUninit<u8>]) -> ProtobufResult<()> { 324 if self.remaining_in_buf_len() >= buf.len() { 325 let buf_len = buf.len(); 326 maybe_uninit_write_slice( 327 buf, 328 &self.buf[self.pos_within_buf..self.pos_within_buf + buf_len], 329 ); 330 self.pos_within_buf += buf_len; 331 return Ok(()); 332 } 333 334 self.read_exact_slow(buf) 335 } 336 337 /// Read exact number of bytes into `Vec`. 338 /// 339 /// `Vec` is cleared in the beginning. read_exact_to_vec(&mut self, count: usize, target: &mut Vec<u8>) -> ProtobufResult<()>340 pub fn read_exact_to_vec(&mut self, count: usize, target: &mut Vec<u8>) -> ProtobufResult<()> { 341 // TODO: also do some limits when reading from unlimited source 342 if count as u64 > self.bytes_until_limit() { 343 return Err(ProtobufError::WireError(WireError::TruncatedMessage)); 344 } 345 346 target.clear(); 347 348 if count >= READ_RAW_BYTES_MAX_ALLOC && count > target.capacity() { 349 // avoid calling `reserve` on buf with very large buffer: could be a malformed message 350 351 target.reserve(READ_RAW_BYTES_MAX_ALLOC); 352 353 while target.len() < count { 354 let need_to_read = count - target.len(); 355 if need_to_read <= target.len() { 356 target.reserve_exact(need_to_read); 357 } else { 358 target.reserve(1); 359 } 360 361 let max = cmp::min(target.capacity() - target.len(), need_to_read); 362 let read = self.read_to_vec(target, max)?; 363 if read == 0 { 364 return Err(ProtobufError::WireError(WireError::TruncatedMessage)); 365 } 366 } 367 } else { 368 target.reserve_exact(count); 369 370 unsafe { 371 self.read_exact(&mut vec_spare_capacity_mut(target)[..count])?; 372 target.set_len(count); 373 } 374 } 375 376 debug_assert_eq!(count, target.len()); 377 378 Ok(()) 379 } 380 do_fill_buf(&mut self) -> ProtobufResult<()>381 fn do_fill_buf(&mut self) -> ProtobufResult<()> { 382 debug_assert!(self.pos_within_buf == self.limit_within_buf); 383 384 // Limit is reached, do not fill buf, because otherwise 385 // synchronous read from `CodedInputStream` may block. 386 if self.limit == self.pos() { 387 return Ok(()); 388 } 389 390 let consume = self.buf.len(); 391 self.pos_of_buf_start += self.buf.len() as u64; 392 self.buf = &[]; 393 self.pos_within_buf = 0; 394 self.limit_within_buf = 0; 395 396 match self.input_source { 397 InputSource::Read(ref mut buf_read) => { 398 buf_read.consume(consume); 399 self.buf = unsafe { mem::transmute(buf_read.fill_buf()?) }; 400 } 401 _ => { 402 return Ok(()); 403 } 404 } 405 406 self.update_limit_within_buf(); 407 408 Ok(()) 409 } 410 411 #[inline(always)] fill_buf(&mut self) -> ProtobufResult<&[u8]>412 pub(crate) fn fill_buf(&mut self) -> ProtobufResult<&[u8]> { 413 if self.pos_within_buf == self.limit_within_buf { 414 self.do_fill_buf()?; 415 } 416 417 Ok(if USE_UNSAFE_FOR_SPEED { 418 unsafe { 419 self.buf 420 .get_unchecked(self.pos_within_buf..self.limit_within_buf) 421 } 422 } else { 423 &self.buf[self.pos_within_buf..self.limit_within_buf] 424 }) 425 } 426 427 #[inline(always)] consume(&mut self, amt: usize)428 pub(crate) fn consume(&mut self, amt: usize) { 429 assert!(amt <= self.limit_within_buf - self.pos_within_buf); 430 self.pos_within_buf += amt; 431 } 432 } 433 434 #[cfg(all(test, feature = "bytes"))] 435 mod test_bytes { 436 use std::io::Write; 437 438 use super::*; 439 make_long_string(len: usize) -> Vec<u8>440 fn make_long_string(len: usize) -> Vec<u8> { 441 let mut s = Vec::new(); 442 while s.len() < len { 443 let len = s.len(); 444 write!(&mut s, "{}", len).expect("unexpected"); 445 } 446 s.truncate(len); 447 s 448 } 449 450 #[test] read_exact_bytes_from_slice()451 fn read_exact_bytes_from_slice() { 452 let bytes = make_long_string(100); 453 let mut bri = BufReadIter::from_byte_slice(&bytes[..]); 454 assert_eq!(&bytes[..90], &bri.read_exact_bytes(90).unwrap()[..]); 455 assert_eq!(bytes[90], bri.read_byte().expect("read_byte")); 456 } 457 458 #[test] read_exact_bytes_from_bytes()459 fn read_exact_bytes_from_bytes() { 460 let bytes = Bytes::from(make_long_string(100)); 461 let mut bri = BufReadIter::from_bytes(&bytes); 462 let read = bri.read_exact_bytes(90).unwrap(); 463 assert_eq!(&bytes[..90], &read[..]); 464 assert_eq!(&bytes[..90].as_ptr(), &read.as_ptr()); 465 assert_eq!(bytes[90], bri.read_byte().expect("read_byte")); 466 } 467 } 468 469 #[cfg(test)] 470 mod test { 471 use std::io; 472 use std::io::BufRead; 473 use std::io::Read; 474 475 use super::*; 476 477 #[test] eof_at_limit()478 fn eof_at_limit() { 479 struct Read5ThenPanic { 480 pos: usize, 481 } 482 483 impl Read for Read5ThenPanic { 484 fn read(&mut self, _buf: &mut [u8]) -> io::Result<usize> { 485 unreachable!(); 486 } 487 } 488 489 impl BufRead for Read5ThenPanic { 490 fn fill_buf(&mut self) -> io::Result<&[u8]> { 491 assert_eq!(0, self.pos); 492 static ZERO_TO_FIVE: &'static [u8] = &[0, 1, 2, 3, 4]; 493 Ok(ZERO_TO_FIVE) 494 } 495 496 fn consume(&mut self, amt: usize) { 497 if amt == 0 { 498 // drop of BufReadIter 499 return; 500 } 501 502 assert_eq!(0, self.pos); 503 assert_eq!(5, amt); 504 self.pos += amt; 505 } 506 } 507 508 let mut read = Read5ThenPanic { pos: 0 }; 509 let mut buf_read_iter = BufReadIter::from_buf_read(&mut read); 510 assert_eq!(0, buf_read_iter.pos()); 511 let _prev_limit = buf_read_iter.push_limit(5); 512 buf_read_iter.read_byte().expect("read_byte"); 513 buf_read_iter 514 .read_exact(&mut [ 515 MaybeUninit::uninit(), 516 MaybeUninit::uninit(), 517 MaybeUninit::uninit(), 518 MaybeUninit::uninit(), 519 ]) 520 .expect("read_exact"); 521 assert!(buf_read_iter.eof().expect("eof")); 522 } 523 } 524