• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 use std::cmp;
2 use std::io::BufRead;
3 use std::io::BufReader;
4 use std::io::Read;
5 use std::mem;
6 use std::mem::MaybeUninit;
7 use std::u64;
8 
9 #[cfg(feature = "bytes")]
10 use bytes::buf::UninitSlice;
11 #[cfg(feature = "bytes")]
12 use bytes::BufMut;
13 #[cfg(feature = "bytes")]
14 use bytes::Bytes;
15 #[cfg(feature = "bytes")]
16 use bytes::BytesMut;
17 
18 use crate::buf_read_or_reader::BufReadOrReader;
19 use crate::coded_input_stream::READ_RAW_BYTES_MAX_ALLOC;
20 use crate::error::WireError;
21 use crate::misc::maybe_uninit_write_slice;
22 use crate::misc::vec_spare_capacity_mut;
23 use crate::ProtobufError;
24 use crate::ProtobufResult;
25 
/// Capacity of the `BufReader` that is created internally when an input
/// stream is constructed from a plain `Read`.
const INPUT_STREAM_BUFFER_SIZE: usize = 4096;

/// When `true`, hot paths index the current buffer with `get_unchecked`
/// instead of bounds-checked slicing.
const USE_UNSAFE_FOR_SPEED: bool = true;

/// Sentinel value of `limit` meaning that no limit is currently set.
const NO_LIMIT: u64 = u64::MAX;
33 
34 /// Hold all possible combinations of input source
35 enum InputSource<'a> {
36     Read(BufReadOrReader<'a>),
37     Slice(&'a [u8]),
38     #[cfg(feature = "bytes")]
39     Bytes(&'a Bytes),
40 }
41 
42 /// Dangerous implementation of `BufRead`.
43 ///
44 /// Unsafe wrapper around BufRead which assumes that `BufRead` buf is
45 /// not moved when `BufRead` is moved.
46 ///
47 /// This assumption is generally incorrect, however, in practice
48 /// `BufReadIter` is created either from `BufRead` reference (which
49 /// cannot  be moved, because it is locked by `CodedInputStream`) or from
50 /// `BufReader` which does not move its buffer (we know that from
51 /// inspecting rust standard library).
52 ///
53 /// It is important for `CodedInputStream` performance that small reads
54 /// (e. g. 4 bytes reads) do not involve virtual calls or switches.
55 /// This is achievable with `BufReadIter`.
56 pub(crate) struct BufReadIter<'a> {
57     input_source: InputSource<'a>,
58     buf: &'a [u8],
59     pos_within_buf: usize,
60     limit_within_buf: usize,
61     pos_of_buf_start: u64,
62     limit: u64,
63 }
64 
65 impl<'a> Drop for BufReadIter<'a> {
drop(&mut self)66     fn drop(&mut self) {
67         match self.input_source {
68             InputSource::Read(ref mut buf_read) => buf_read.consume(self.pos_within_buf),
69             _ => {}
70         }
71     }
72 }
73 
74 impl<'ignore> BufReadIter<'ignore> {
from_read<'a>(read: &'a mut dyn Read) -> BufReadIter<'a>75     pub(crate) fn from_read<'a>(read: &'a mut dyn Read) -> BufReadIter<'a> {
76         BufReadIter {
77             input_source: InputSource::Read(BufReadOrReader::BufReader(BufReader::with_capacity(
78                 INPUT_STREAM_BUFFER_SIZE,
79                 read,
80             ))),
81             buf: &[],
82             pos_within_buf: 0,
83             limit_within_buf: 0,
84             pos_of_buf_start: 0,
85             limit: NO_LIMIT,
86         }
87     }
88 
from_buf_read<'a>(buf_read: &'a mut dyn BufRead) -> BufReadIter<'a>89     pub(crate) fn from_buf_read<'a>(buf_read: &'a mut dyn BufRead) -> BufReadIter<'a> {
90         BufReadIter {
91             input_source: InputSource::Read(BufReadOrReader::BufRead(buf_read)),
92             buf: &[],
93             pos_within_buf: 0,
94             limit_within_buf: 0,
95             pos_of_buf_start: 0,
96             limit: NO_LIMIT,
97         }
98     }
99 
from_byte_slice<'a>(bytes: &'a [u8]) -> BufReadIter<'a>100     pub(crate) fn from_byte_slice<'a>(bytes: &'a [u8]) -> BufReadIter<'a> {
101         BufReadIter {
102             input_source: InputSource::Slice(bytes),
103             buf: bytes,
104             pos_within_buf: 0,
105             limit_within_buf: bytes.len(),
106             pos_of_buf_start: 0,
107             limit: NO_LIMIT,
108         }
109     }
110 
111     #[cfg(feature = "bytes")]
from_bytes<'a>(bytes: &'a Bytes) -> BufReadIter<'a>112     pub(crate) fn from_bytes<'a>(bytes: &'a Bytes) -> BufReadIter<'a> {
113         BufReadIter {
114             input_source: InputSource::Bytes(bytes),
115             buf: &bytes,
116             pos_within_buf: 0,
117             limit_within_buf: bytes.len(),
118             pos_of_buf_start: 0,
119             limit: NO_LIMIT,
120         }
121     }
122 
123     #[inline]
assertions(&self)124     fn assertions(&self) {
125         debug_assert!(self.pos_within_buf <= self.limit_within_buf);
126         debug_assert!(self.limit_within_buf <= self.buf.len());
127         debug_assert!(self.pos_of_buf_start + self.pos_within_buf as u64 <= self.limit);
128     }
129 
130     #[inline(always)]
pos(&self) -> u64131     pub(crate) fn pos(&self) -> u64 {
132         self.pos_of_buf_start + self.pos_within_buf as u64
133     }
134 
135     /// Recompute `limit_within_buf` after update of `limit`
136     #[inline]
update_limit_within_buf(&mut self)137     fn update_limit_within_buf(&mut self) {
138         if self.pos_of_buf_start + (self.buf.len() as u64) <= self.limit {
139             self.limit_within_buf = self.buf.len();
140         } else {
141             self.limit_within_buf = (self.limit - self.pos_of_buf_start) as usize;
142         }
143 
144         self.assertions();
145     }
146 
push_limit(&mut self, limit: u64) -> ProtobufResult<u64>147     pub(crate) fn push_limit(&mut self, limit: u64) -> ProtobufResult<u64> {
148         let new_limit = match self.pos().checked_add(limit) {
149             Some(new_limit) => new_limit,
150             None => return Err(ProtobufError::WireError(WireError::Other)),
151         };
152 
153         if new_limit > self.limit {
154             return Err(ProtobufError::WireError(WireError::Other));
155         }
156 
157         let prev_limit = mem::replace(&mut self.limit, new_limit);
158 
159         self.update_limit_within_buf();
160 
161         Ok(prev_limit)
162     }
163 
164     #[inline]
pop_limit(&mut self, limit: u64)165     pub(crate) fn pop_limit(&mut self, limit: u64) {
166         assert!(limit >= self.limit);
167 
168         self.limit = limit;
169 
170         self.update_limit_within_buf();
171     }
172 
173     #[inline]
remaining_in_buf(&self) -> &[u8]174     pub(crate) fn remaining_in_buf(&self) -> &[u8] {
175         if USE_UNSAFE_FOR_SPEED {
176             unsafe {
177                 &self
178                     .buf
179                     .get_unchecked(self.pos_within_buf..self.limit_within_buf)
180             }
181         } else {
182             &self.buf[self.pos_within_buf..self.limit_within_buf]
183         }
184     }
185 
186     #[inline(always)]
remaining_in_buf_len(&self) -> usize187     pub(crate) fn remaining_in_buf_len(&self) -> usize {
188         self.limit_within_buf - self.pos_within_buf
189     }
190 
191     #[inline(always)]
bytes_until_limit(&self) -> u64192     pub(crate) fn bytes_until_limit(&self) -> u64 {
193         if self.limit == NO_LIMIT {
194             NO_LIMIT
195         } else {
196             self.limit - (self.pos_of_buf_start + self.pos_within_buf as u64)
197         }
198     }
199 
200     #[inline(always)]
eof(&mut self) -> ProtobufResult<bool>201     pub(crate) fn eof(&mut self) -> ProtobufResult<bool> {
202         if self.pos_within_buf == self.limit_within_buf {
203             Ok(self.fill_buf()?.is_empty())
204         } else {
205             Ok(false)
206         }
207     }
208 
209     #[inline(always)]
read_byte(&mut self) -> ProtobufResult<u8>210     pub(crate) fn read_byte(&mut self) -> ProtobufResult<u8> {
211         if self.pos_within_buf == self.limit_within_buf {
212             self.do_fill_buf()?;
213             if self.remaining_in_buf_len() == 0 {
214                 return Err(ProtobufError::WireError(WireError::UnexpectedEof));
215             }
216         }
217 
218         let r = if USE_UNSAFE_FOR_SPEED {
219             unsafe { *self.buf.get_unchecked(self.pos_within_buf) }
220         } else {
221             self.buf[self.pos_within_buf]
222         };
223         self.pos_within_buf += 1;
224         Ok(r)
225     }
226 
227     /// Read at most `max` bytes, append to `Vec`.
228     ///
229     /// Returns 0 when EOF or limit reached.
read_to_vec(&mut self, vec: &mut Vec<u8>, max: usize) -> ProtobufResult<usize>230     fn read_to_vec(&mut self, vec: &mut Vec<u8>, max: usize) -> ProtobufResult<usize> {
231         let len = {
232             let rem = self.fill_buf()?;
233 
234             let len = cmp::min(rem.len(), max);
235             vec.extend_from_slice(&rem[..len]);
236             len
237         };
238         self.pos_within_buf += len;
239         Ok(len)
240     }
241 
242     #[cfg(feature = "bytes")]
read_exact_bytes(&mut self, len: usize) -> ProtobufResult<Bytes>243     pub(crate) fn read_exact_bytes(&mut self, len: usize) -> ProtobufResult<Bytes> {
244         if let InputSource::Bytes(bytes) = self.input_source {
245             let end = match self.pos_within_buf.checked_add(len) {
246                 Some(end) => end,
247                 None => return Err(ProtobufError::WireError(WireError::UnexpectedEof)),
248             };
249 
250             if end > self.limit_within_buf {
251                 return Err(ProtobufError::WireError(WireError::UnexpectedEof));
252             }
253 
254             let r = bytes.slice(self.pos_within_buf..end);
255             self.pos_within_buf += len;
256             Ok(r)
257         } else {
258             if len >= READ_RAW_BYTES_MAX_ALLOC {
259                 // We cannot trust `len` because protobuf message could be malformed.
260                 // Reading should not result in OOM when allocating a buffer.
261                 let mut v = Vec::new();
262                 self.read_exact_to_vec(len, &mut v)?;
263                 Ok(Bytes::from(v))
264             } else {
265                 let mut r = BytesMut::with_capacity(len);
266                 unsafe {
267                     let buf = Self::uninit_slice_as_mut_slice(&mut r.chunk_mut()[..len]);
268                     self.read_exact(buf)?;
269                     r.advance_mut(len);
270                 }
271                 Ok(r.freeze())
272             }
273         }
274     }
275 
276     #[cfg(feature = "bytes")]
uninit_slice_as_mut_slice(slice: &mut UninitSlice) -> &mut [MaybeUninit<u8>]277     unsafe fn uninit_slice_as_mut_slice(slice: &mut UninitSlice) -> &mut [MaybeUninit<u8>] {
278         use std::slice;
279         slice::from_raw_parts_mut(slice.as_mut_ptr() as *mut MaybeUninit<u8>, slice.len())
280     }
281 
282     /// Returns 0 when EOF or limit reached.
read(&mut self, buf: &mut [u8]) -> ProtobufResult<usize>283     pub(crate) fn read(&mut self, buf: &mut [u8]) -> ProtobufResult<usize> {
284         self.fill_buf()?;
285 
286         let rem = &self.buf[self.pos_within_buf..self.limit_within_buf];
287 
288         let len = cmp::min(rem.len(), buf.len());
289         buf[..len].copy_from_slice(&rem[..len]);
290         self.pos_within_buf += len;
291         Ok(len)
292     }
293 
read_exact_slow(&mut self, buf: &mut [MaybeUninit<u8>]) -> ProtobufResult<()>294     fn read_exact_slow(&mut self, buf: &mut [MaybeUninit<u8>]) -> ProtobufResult<()> {
295         if self.bytes_until_limit() < buf.len() as u64 {
296             return Err(ProtobufError::WireError(WireError::UnexpectedEof));
297         }
298 
299         let consume = self.pos_within_buf;
300         self.pos_of_buf_start += self.pos_within_buf as u64;
301         self.pos_within_buf = 0;
302         self.buf = &[];
303         self.limit_within_buf = 0;
304 
305         match self.input_source {
306             InputSource::Read(ref mut buf_read) => {
307                 buf_read.consume(consume);
308                 buf_read.read_exact_uninit(buf)?;
309             }
310             _ => {
311                 return Err(ProtobufError::WireError(WireError::UnexpectedEof));
312             }
313         }
314 
315         self.pos_of_buf_start += buf.len() as u64;
316 
317         self.assertions();
318 
319         Ok(())
320     }
321 
322     #[inline]
read_exact(&mut self, buf: &mut [MaybeUninit<u8>]) -> ProtobufResult<()>323     pub(crate) fn read_exact(&mut self, buf: &mut [MaybeUninit<u8>]) -> ProtobufResult<()> {
324         if self.remaining_in_buf_len() >= buf.len() {
325             let buf_len = buf.len();
326             maybe_uninit_write_slice(
327                 buf,
328                 &self.buf[self.pos_within_buf..self.pos_within_buf + buf_len],
329             );
330             self.pos_within_buf += buf_len;
331             return Ok(());
332         }
333 
334         self.read_exact_slow(buf)
335     }
336 
337     /// Read exact number of bytes into `Vec`.
338     ///
339     /// `Vec` is cleared in the beginning.
read_exact_to_vec(&mut self, count: usize, target: &mut Vec<u8>) -> ProtobufResult<()>340     pub fn read_exact_to_vec(&mut self, count: usize, target: &mut Vec<u8>) -> ProtobufResult<()> {
341         // TODO: also do some limits when reading from unlimited source
342         if count as u64 > self.bytes_until_limit() {
343             return Err(ProtobufError::WireError(WireError::TruncatedMessage));
344         }
345 
346         target.clear();
347 
348         if count >= READ_RAW_BYTES_MAX_ALLOC && count > target.capacity() {
349             // avoid calling `reserve` on buf with very large buffer: could be a malformed message
350 
351             target.reserve(READ_RAW_BYTES_MAX_ALLOC);
352 
353             while target.len() < count {
354                 let need_to_read = count - target.len();
355                 if need_to_read <= target.len() {
356                     target.reserve_exact(need_to_read);
357                 } else {
358                     target.reserve(1);
359                 }
360 
361                 let max = cmp::min(target.capacity() - target.len(), need_to_read);
362                 let read = self.read_to_vec(target, max)?;
363                 if read == 0 {
364                     return Err(ProtobufError::WireError(WireError::TruncatedMessage));
365                 }
366             }
367         } else {
368             target.reserve_exact(count);
369 
370             unsafe {
371                 self.read_exact(&mut vec_spare_capacity_mut(target)[..count])?;
372                 target.set_len(count);
373             }
374         }
375 
376         debug_assert_eq!(count, target.len());
377 
378         Ok(())
379     }
380 
do_fill_buf(&mut self) -> ProtobufResult<()>381     fn do_fill_buf(&mut self) -> ProtobufResult<()> {
382         debug_assert!(self.pos_within_buf == self.limit_within_buf);
383 
384         // Limit is reached, do not fill buf, because otherwise
385         // synchronous read from `CodedInputStream` may block.
386         if self.limit == self.pos() {
387             return Ok(());
388         }
389 
390         let consume = self.buf.len();
391         self.pos_of_buf_start += self.buf.len() as u64;
392         self.buf = &[];
393         self.pos_within_buf = 0;
394         self.limit_within_buf = 0;
395 
396         match self.input_source {
397             InputSource::Read(ref mut buf_read) => {
398                 buf_read.consume(consume);
399                 self.buf = unsafe { mem::transmute(buf_read.fill_buf()?) };
400             }
401             _ => {
402                 return Ok(());
403             }
404         }
405 
406         self.update_limit_within_buf();
407 
408         Ok(())
409     }
410 
411     #[inline(always)]
fill_buf(&mut self) -> ProtobufResult<&[u8]>412     pub(crate) fn fill_buf(&mut self) -> ProtobufResult<&[u8]> {
413         if self.pos_within_buf == self.limit_within_buf {
414             self.do_fill_buf()?;
415         }
416 
417         Ok(if USE_UNSAFE_FOR_SPEED {
418             unsafe {
419                 self.buf
420                     .get_unchecked(self.pos_within_buf..self.limit_within_buf)
421             }
422         } else {
423             &self.buf[self.pos_within_buf..self.limit_within_buf]
424         })
425     }
426 
427     #[inline(always)]
consume(&mut self, amt: usize)428     pub(crate) fn consume(&mut self, amt: usize) {
429         assert!(amt <= self.limit_within_buf - self.pos_within_buf);
430         self.pos_within_buf += amt;
431     }
432 }
433 
#[cfg(all(test, feature = "bytes"))]
mod test_bytes {
    use std::io::Write;

    use super::*;

    /// Builds `len` bytes of deterministic test data by repeatedly appending
    /// the decimal representation of the current length.
    fn make_long_string(len: usize) -> Vec<u8> {
        let mut data = Vec::new();
        while data.len() < len {
            let so_far = data.len();
            write!(&mut data, "{}", so_far).expect("unexpected");
        }
        data.truncate(len);
        data
    }

    /// Reading from a slice source copies the requested bytes and leaves the
    /// iterator positioned right after them.
    #[test]
    fn read_exact_bytes_from_slice() {
        let data = make_long_string(100);
        let mut iter = BufReadIter::from_byte_slice(&data[..]);

        let head = iter.read_exact_bytes(90).unwrap();
        assert_eq!(&data[..90], &head[..]);

        let next = iter.read_byte().expect("read_byte");
        assert_eq!(data[90], next);
    }

    /// Reading from a `Bytes` source returns a slice that points into the
    /// original storage instead of a copy.
    #[test]
    fn read_exact_bytes_from_bytes() {
        let data = Bytes::from(make_long_string(100));
        let mut iter = BufReadIter::from_bytes(&data);

        let head = iter.read_exact_bytes(90).unwrap();
        assert_eq!(&data[..90], &head[..]);
        // Same start address: no copy was made.
        assert_eq!(data[..90].as_ptr(), head.as_ptr());

        let next = iter.read_byte().expect("read_byte");
        assert_eq!(data[90], next);
    }
}
468 
#[cfg(test)]
mod test {
    use std::io;
    use std::io::BufRead;
    use std::io::Read;

    use super::*;

    /// After reading exactly up to a pushed limit, `eof` must report `true`
    /// without asking the underlying reader for more data.
    #[test]
    fn eof_at_limit() {
        /// Reader that serves a single 5-byte buffer and fails the test if
        /// it is asked for anything beyond it.
        struct Read5ThenPanic {
            pos: usize,
        }

        impl Read for Read5ThenPanic {
            fn read(&mut self, _buf: &mut [u8]) -> io::Result<usize> {
                unreachable!();
            }
        }

        impl BufRead for Read5ThenPanic {
            fn fill_buf(&mut self) -> io::Result<&[u8]> {
                // A second fill after the buffer was consumed is the bug
                // this test guards against.
                assert_eq!(0, self.pos);
                static FIVE_BYTES: &[u8] = &[0, 1, 2, 3, 4];
                Ok(FIVE_BYTES)
            }

            fn consume(&mut self, amt: usize) {
                if amt == 0 {
                    // drop of BufReadIter
                    return;
                }

                assert_eq!(0, self.pos);
                assert_eq!(5, amt);
                self.pos += amt;
            }
        }

        let mut read = Read5ThenPanic { pos: 0 };
        let mut buf_read_iter = BufReadIter::from_buf_read(&mut read);
        assert_eq!(0, buf_read_iter.pos());
        // Fail right here if the limit cannot be set, instead of letting the
        // test continue without a limit.
        let _prev_limit = buf_read_iter.push_limit(5).expect("push_limit");
        buf_read_iter.read_byte().expect("read_byte");
        buf_read_iter
            .read_exact(&mut [
                MaybeUninit::uninit(),
                MaybeUninit::uninit(),
                MaybeUninit::uninit(),
                MaybeUninit::uninit(),
            ])
            .expect("read_exact");
        assert!(buf_read_iter.eof().expect("eof"));
    }
}
524