1 use std::cmp; 2 use std::io::BufRead; 3 use std::io::BufReader; 4 use std::io::Read; 5 use std::mem; 6 use std::u64; 7 8 #[cfg(feature = "bytes")] 9 use bytes::buf::UninitSlice; 10 #[cfg(feature = "bytes")] 11 use bytes::BufMut; 12 #[cfg(feature = "bytes")] 13 use bytes::Bytes; 14 #[cfg(feature = "bytes")] 15 use bytes::BytesMut; 16 17 use crate::coded_input_stream::READ_RAW_BYTES_MAX_ALLOC; 18 use crate::error::WireError; 19 use crate::ProtobufError; 20 use crate::ProtobufResult; 21 22 // If an input stream is constructed with a `Read`, we create a 23 // `BufReader` with an internal buffer of this size. 24 const INPUT_STREAM_BUFFER_SIZE: usize = 4096; 25 26 const USE_UNSAFE_FOR_SPEED: bool = true; 27 28 const NO_LIMIT: u64 = u64::MAX; 29 30 /// Hold all possible combinations of input source 31 enum InputSource<'a> { 32 BufRead(&'a mut dyn BufRead), 33 Read(BufReader<&'a mut dyn Read>), 34 Slice(&'a [u8]), 35 #[cfg(feature = "bytes")] 36 Bytes(&'a Bytes), 37 } 38 39 /// Dangerous implementation of `BufRead`. 40 /// 41 /// Unsafe wrapper around BufRead which assumes that `BufRead` buf is 42 /// not moved when `BufRead` is moved. 43 /// 44 /// This assumption is generally incorrect, however, in practice 45 /// `BufReadIter` is created either from `BufRead` reference (which 46 /// cannot be moved, because it is locked by `CodedInputStream`) or from 47 /// `BufReader` which does not move its buffer (we know that from 48 /// inspecting rust standard library). 49 /// 50 /// It is important for `CodedInputStream` performance that small reads 51 /// (e. g. 4 bytes reads) do not involve virtual calls or switches. 52 /// This is achievable with `BufReadIter`. 53 pub struct BufReadIter<'a> { 54 input_source: InputSource<'a>, 55 buf: &'a [u8], 56 pos_within_buf: usize, 57 limit_within_buf: usize, 58 pos_of_buf_start: u64, 59 limit: u64, 60 } 61 62 impl<'a> Drop for BufReadIter<'a> { drop(&mut self)63 fn drop(&mut self) { 64 match self.input_source { 65 InputSource::BufRead(ref mut buf_read) => buf_read.consume(self.pos_within_buf), 66 InputSource::Read(_) => { 67 // Nothing to flush, because we own BufReader 68 } 69 _ => {} 70 } 71 } 72 } 73 74 impl<'ignore> BufReadIter<'ignore> { from_read<'a>(read: &'a mut dyn Read) -> BufReadIter<'a>75 pub fn from_read<'a>(read: &'a mut dyn Read) -> BufReadIter<'a> { 76 BufReadIter { 77 input_source: InputSource::Read(BufReader::with_capacity( 78 INPUT_STREAM_BUFFER_SIZE, 79 read, 80 )), 81 buf: &[], 82 pos_within_buf: 0, 83 limit_within_buf: 0, 84 pos_of_buf_start: 0, 85 limit: NO_LIMIT, 86 } 87 } 88 from_buf_read<'a>(buf_read: &'a mut dyn BufRead) -> BufReadIter<'a>89 pub fn from_buf_read<'a>(buf_read: &'a mut dyn BufRead) -> BufReadIter<'a> { 90 BufReadIter { 91 input_source: InputSource::BufRead(buf_read), 92 buf: &[], 93 pos_within_buf: 0, 94 limit_within_buf: 0, 95 pos_of_buf_start: 0, 96 limit: NO_LIMIT, 97 } 98 } 99 from_byte_slice<'a>(bytes: &'a [u8]) -> BufReadIter<'a>100 pub fn from_byte_slice<'a>(bytes: &'a [u8]) -> BufReadIter<'a> { 101 BufReadIter { 102 input_source: InputSource::Slice(bytes), 103 buf: bytes, 104 pos_within_buf: 0, 105 limit_within_buf: bytes.len(), 106 pos_of_buf_start: 0, 107 limit: NO_LIMIT, 108 } 109 } 110 111 #[cfg(feature = "bytes")] from_bytes<'a>(bytes: &'a Bytes) -> BufReadIter<'a>112 pub fn from_bytes<'a>(bytes: &'a Bytes) -> BufReadIter<'a> { 113 BufReadIter { 114 input_source: InputSource::Bytes(bytes), 115 buf: &bytes, 116 pos_within_buf: 0, 117 limit_within_buf: bytes.len(), 118 pos_of_buf_start: 0, 119 limit: NO_LIMIT, 120 } 121 } 122 123 #[inline] assertions(&self)124 fn assertions(&self) { 125 debug_assert!(self.pos_within_buf <= self.limit_within_buf); 126 debug_assert!(self.limit_within_buf <= self.buf.len()); 127 debug_assert!(self.pos_of_buf_start + self.pos_within_buf as u64 <= self.limit); 128 } 129 130 #[inline(always)] pos(&self) -> u64131 pub fn pos(&self) -> u64 { 132 self.pos_of_buf_start + self.pos_within_buf as u64 133 } 134 135 /// Recompute `limit_within_buf` after update of `limit` 136 #[inline] update_limit_within_buf(&mut self)137 fn update_limit_within_buf(&mut self) { 138 if self.pos_of_buf_start + (self.buf.len() as u64) <= self.limit { 139 self.limit_within_buf = self.buf.len(); 140 } else { 141 self.limit_within_buf = (self.limit - self.pos_of_buf_start) as usize; 142 } 143 144 self.assertions(); 145 } 146 push_limit(&mut self, limit: u64) -> ProtobufResult<u64>147 pub fn push_limit(&mut self, limit: u64) -> ProtobufResult<u64> { 148 let new_limit = match self.pos().checked_add(limit) { 149 Some(new_limit) => new_limit, 150 None => return Err(ProtobufError::WireError(WireError::Other)), 151 }; 152 153 if new_limit > self.limit { 154 return Err(ProtobufError::WireError(WireError::Other)); 155 } 156 157 let prev_limit = mem::replace(&mut self.limit, new_limit); 158 159 self.update_limit_within_buf(); 160 161 Ok(prev_limit) 162 } 163 164 #[inline] pop_limit(&mut self, limit: u64)165 pub fn pop_limit(&mut self, limit: u64) { 166 assert!(limit >= self.limit); 167 168 self.limit = limit; 169 170 self.update_limit_within_buf(); 171 } 172 173 #[inline] remaining_in_buf(&self) -> &[u8]174 pub fn remaining_in_buf(&self) -> &[u8] { 175 if USE_UNSAFE_FOR_SPEED { 176 unsafe { 177 &self 178 .buf 179 .get_unchecked(self.pos_within_buf..self.limit_within_buf) 180 } 181 } else { 182 &self.buf[self.pos_within_buf..self.limit_within_buf] 183 } 184 } 185 186 #[inline(always)] remaining_in_buf_len(&self) -> usize187 pub fn remaining_in_buf_len(&self) -> usize { 188 self.limit_within_buf - self.pos_within_buf 189 } 190 191 #[inline(always)] bytes_until_limit(&self) -> u64192 pub fn bytes_until_limit(&self) -> u64 { 193 if self.limit == NO_LIMIT { 194 NO_LIMIT 195 } else { 196 self.limit - (self.pos_of_buf_start + self.pos_within_buf as u64) 197 } 198 } 199 200 #[inline(always)] eof(&mut self) -> ProtobufResult<bool>201 pub fn eof(&mut self) -> ProtobufResult<bool> { 202 if self.pos_within_buf == self.limit_within_buf { 203 Ok(self.fill_buf()?.is_empty()) 204 } else { 205 Ok(false) 206 } 207 } 208 209 #[inline(always)] read_byte(&mut self) -> ProtobufResult<u8>210 pub fn read_byte(&mut self) -> ProtobufResult<u8> { 211 if self.pos_within_buf == self.limit_within_buf { 212 self.do_fill_buf()?; 213 if self.remaining_in_buf_len() == 0 { 214 return Err(ProtobufError::WireError(WireError::UnexpectedEof)); 215 } 216 } 217 218 let r = if USE_UNSAFE_FOR_SPEED { 219 unsafe { *self.buf.get_unchecked(self.pos_within_buf) } 220 } else { 221 self.buf[self.pos_within_buf] 222 }; 223 self.pos_within_buf += 1; 224 Ok(r) 225 } 226 227 /// Read at most `max` bytes, append to `Vec`. 228 /// 229 /// Returns 0 when EOF or limit reached. read_to_vec(&mut self, vec: &mut Vec<u8>, max: usize) -> ProtobufResult<usize>230 fn read_to_vec(&mut self, vec: &mut Vec<u8>, max: usize) -> ProtobufResult<usize> { 231 let len = { 232 let rem = self.fill_buf()?; 233 234 let len = cmp::min(rem.len(), max); 235 vec.extend_from_slice(&rem[..len]); 236 len 237 }; 238 self.pos_within_buf += len; 239 Ok(len) 240 } 241 242 /// Read exact number of bytes into `Vec`. 243 /// 244 /// `Vec` is cleared in the beginning. read_exact_to_vec(&mut self, count: usize, target: &mut Vec<u8>) -> ProtobufResult<()>245 pub fn read_exact_to_vec(&mut self, count: usize, target: &mut Vec<u8>) -> ProtobufResult<()> { 246 // TODO: also do some limits when reading from unlimited source 247 if count as u64 > self.bytes_until_limit() { 248 return Err(ProtobufError::WireError(WireError::TruncatedMessage)); 249 } 250 251 target.clear(); 252 253 if count >= READ_RAW_BYTES_MAX_ALLOC && count > target.capacity() { 254 // avoid calling `reserve` on buf with very large buffer: could be a malformed message 255 256 target.reserve(READ_RAW_BYTES_MAX_ALLOC); 257 258 while target.len() < count { 259 let need_to_read = count - target.len(); 260 if need_to_read <= target.len() { 261 target.reserve_exact(need_to_read); 262 } else { 263 target.reserve(1); 264 } 265 266 let max = cmp::min(target.capacity() - target.len(), need_to_read); 267 let read = self.read_to_vec(target, max)?; 268 if read == 0 { 269 return Err(ProtobufError::WireError(WireError::TruncatedMessage)); 270 } 271 } 272 } else { 273 target.reserve_exact(count); 274 275 unsafe { 276 self.read_exact(&mut target.get_unchecked_mut(..count))?; 277 target.set_len(count); 278 } 279 } 280 281 debug_assert_eq!(count, target.len()); 282 283 Ok(()) 284 } 285 286 #[cfg(feature = "bytes")] read_exact_bytes(&mut self, len: usize) -> ProtobufResult<Bytes>287 pub fn read_exact_bytes(&mut self, len: usize) -> ProtobufResult<Bytes> { 288 if let InputSource::Bytes(bytes) = self.input_source { 289 let end = match self.pos_within_buf.checked_add(len) { 290 Some(end) => end, 291 None => return Err(ProtobufError::WireError(WireError::UnexpectedEof)), 292 }; 293 294 if end > self.limit_within_buf { 295 return Err(ProtobufError::WireError(WireError::UnexpectedEof)); 296 } 297 298 let r = bytes.slice(self.pos_within_buf..end); 299 self.pos_within_buf += len; 300 Ok(r) 301 } else { 302 if len >= READ_RAW_BYTES_MAX_ALLOC { 303 // We cannot trust `len` because protobuf message could be malformed. 304 // Reading should not result in OOM when allocating a buffer. 305 let mut v = Vec::new(); 306 self.read_exact_to_vec(len, &mut v)?; 307 Ok(Bytes::from(v)) 308 } else { 309 let mut r = BytesMut::with_capacity(len); 310 unsafe { 311 let buf = Self::uninit_slice_as_mut_slice(&mut r.chunk_mut()[..len]); 312 self.read_exact(buf)?; 313 r.advance_mut(len); 314 } 315 Ok(r.freeze()) 316 } 317 } 318 } 319 320 #[cfg(feature = "bytes")] uninit_slice_as_mut_slice(slice: &mut UninitSlice) -> &mut [u8]321 unsafe fn uninit_slice_as_mut_slice(slice: &mut UninitSlice) -> &mut [u8] { 322 use std::slice; 323 slice::from_raw_parts_mut(slice.as_mut_ptr(), slice.len()) 324 } 325 326 /// Returns 0 when EOF or limit reached. read(&mut self, buf: &mut [u8]) -> ProtobufResult<usize>327 pub fn read(&mut self, buf: &mut [u8]) -> ProtobufResult<usize> { 328 self.fill_buf()?; 329 330 let rem = &self.buf[self.pos_within_buf..self.limit_within_buf]; 331 332 let len = cmp::min(rem.len(), buf.len()); 333 &mut buf[..len].copy_from_slice(&rem[..len]); 334 self.pos_within_buf += len; 335 Ok(len) 336 } 337 read_exact(&mut self, buf: &mut [u8]) -> ProtobufResult<()>338 pub fn read_exact(&mut self, buf: &mut [u8]) -> ProtobufResult<()> { 339 if self.remaining_in_buf_len() >= buf.len() { 340 let buf_len = buf.len(); 341 buf.copy_from_slice(&self.buf[self.pos_within_buf..self.pos_within_buf + buf_len]); 342 self.pos_within_buf += buf_len; 343 return Ok(()); 344 } 345 346 if self.bytes_until_limit() < buf.len() as u64 { 347 return Err(ProtobufError::WireError(WireError::UnexpectedEof)); 348 } 349 350 let consume = self.pos_within_buf; 351 self.pos_of_buf_start += self.pos_within_buf as u64; 352 self.pos_within_buf = 0; 353 self.buf = &[]; 354 self.limit_within_buf = 0; 355 356 match self.input_source { 357 InputSource::Read(ref mut buf_read) => { 358 buf_read.consume(consume); 359 buf_read.read_exact(buf)?; 360 } 361 InputSource::BufRead(ref mut buf_read) => { 362 buf_read.consume(consume); 363 buf_read.read_exact(buf)?; 364 } 365 _ => { 366 return Err(ProtobufError::WireError(WireError::UnexpectedEof)); 367 } 368 } 369 370 self.pos_of_buf_start += buf.len() as u64; 371 372 self.assertions(); 373 374 Ok(()) 375 } 376 do_fill_buf(&mut self) -> ProtobufResult<()>377 fn do_fill_buf(&mut self) -> ProtobufResult<()> { 378 debug_assert!(self.pos_within_buf == self.limit_within_buf); 379 380 // Limit is reached, do not fill buf, because otherwise 381 // synchronous read from `CodedInputStream` may block. 382 if self.limit == self.pos() { 383 return Ok(()); 384 } 385 386 let consume = self.buf.len(); 387 self.pos_of_buf_start += self.buf.len() as u64; 388 self.buf = &[]; 389 self.pos_within_buf = 0; 390 self.limit_within_buf = 0; 391 392 match self.input_source { 393 InputSource::Read(ref mut buf_read) => { 394 buf_read.consume(consume); 395 self.buf = unsafe { mem::transmute(buf_read.fill_buf()?) }; 396 } 397 InputSource::BufRead(ref mut buf_read) => { 398 buf_read.consume(consume); 399 self.buf = unsafe { mem::transmute(buf_read.fill_buf()?) }; 400 } 401 _ => { 402 return Ok(()); 403 } 404 } 405 406 self.update_limit_within_buf(); 407 408 Ok(()) 409 } 410 411 #[inline(always)] fill_buf(&mut self) -> ProtobufResult<&[u8]>412 pub fn fill_buf(&mut self) -> ProtobufResult<&[u8]> { 413 if self.pos_within_buf == self.limit_within_buf { 414 self.do_fill_buf()?; 415 } 416 417 Ok(if USE_UNSAFE_FOR_SPEED { 418 unsafe { 419 self.buf 420 .get_unchecked(self.pos_within_buf..self.limit_within_buf) 421 } 422 } else { 423 &self.buf[self.pos_within_buf..self.limit_within_buf] 424 }) 425 } 426 427 #[inline(always)] consume(&mut self, amt: usize)428 pub fn consume(&mut self, amt: usize) { 429 assert!(amt <= self.limit_within_buf - self.pos_within_buf); 430 self.pos_within_buf += amt; 431 } 432 } 433 434 #[cfg(all(test, feature = "bytes"))] 435 mod test_bytes { 436 use super::*; 437 use std::io::Write; 438 make_long_string(len: usize) -> Vec<u8>439 fn make_long_string(len: usize) -> Vec<u8> { 440 let mut s = Vec::new(); 441 while s.len() < len { 442 let len = s.len(); 443 write!(&mut s, "{}", len).expect("unexpected"); 444 } 445 s.truncate(len); 446 s 447 } 448 449 #[test] read_exact_bytes_from_slice()450 fn read_exact_bytes_from_slice() { 451 let bytes = make_long_string(100); 452 let mut bri = BufReadIter::from_byte_slice(&bytes[..]); 453 assert_eq!(&bytes[..90], &bri.read_exact_bytes(90).unwrap()[..]); 454 assert_eq!(bytes[90], bri.read_byte().expect("read_byte")); 455 } 456 457 #[test] read_exact_bytes_from_bytes()458 fn read_exact_bytes_from_bytes() { 459 let bytes = Bytes::from(make_long_string(100)); 460 let mut bri = BufReadIter::from_bytes(&bytes); 461 let read = bri.read_exact_bytes(90).unwrap(); 462 assert_eq!(&bytes[..90], &read[..]); 463 assert_eq!(&bytes[..90].as_ptr(), &read.as_ptr()); 464 assert_eq!(bytes[90], bri.read_byte().expect("read_byte")); 465 } 466 } 467 468 #[cfg(test)] 469 mod test { 470 use super::*; 471 use std::io; 472 use std::io::BufRead; 473 use std::io::Read; 474 475 #[test] eof_at_limit()476 fn eof_at_limit() { 477 struct Read5ThenPanic { 478 pos: usize, 479 } 480 481 impl Read for Read5ThenPanic { 482 fn read(&mut self, _buf: &mut [u8]) -> io::Result<usize> { 483 unreachable!(); 484 } 485 } 486 487 impl BufRead for Read5ThenPanic { 488 fn fill_buf(&mut self) -> io::Result<&[u8]> { 489 assert_eq!(0, self.pos); 490 static ZERO_TO_FIVE: &'static [u8] = &[0, 1, 2, 3, 4]; 491 Ok(ZERO_TO_FIVE) 492 } 493 494 fn consume(&mut self, amt: usize) { 495 if amt == 0 { 496 // drop of BufReadIter 497 return; 498 } 499 500 assert_eq!(0, self.pos); 501 assert_eq!(5, amt); 502 self.pos += amt; 503 } 504 } 505 506 let mut read = Read5ThenPanic { pos: 0 }; 507 let mut buf_read_iter = BufReadIter::from_buf_read(&mut read); 508 assert_eq!(0, buf_read_iter.pos()); 509 let _prev_limit = buf_read_iter.push_limit(5); 510 buf_read_iter.read_byte().expect("read_byte"); 511 buf_read_iter 512 .read_exact(&mut [1, 2, 3, 4]) 513 .expect("read_exact"); 514 assert!(buf_read_iter.eof().expect("eof")); 515 } 516 } 517