• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 use std::cmp;
2 use std::io::BufRead;
3 use std::io::BufReader;
4 use std::io::Read;
5 use std::mem;
6 use std::mem::MaybeUninit;
7 
8 #[cfg(feature = "bytes")]
9 use bytes::buf::UninitSlice;
10 #[cfg(feature = "bytes")]
11 use bytes::BufMut;
12 #[cfg(feature = "bytes")]
13 use bytes::Bytes;
14 #[cfg(feature = "bytes")]
15 use bytes::BytesMut;
16 
17 use crate::coded_input_stream::buf_read_or_reader::BufReadOrReader;
18 use crate::coded_input_stream::input_buf::InputBuf;
19 use crate::coded_input_stream::input_source::InputSource;
20 use crate::coded_input_stream::READ_RAW_BYTES_MAX_ALLOC;
21 use crate::error::ProtobufError;
22 use crate::error::WireError;
23 
// If an input stream is constructed with a `Read`, we create a
// `BufReader` with an internal buffer of this size.
const INPUT_STREAM_BUFFER_SIZE: usize = 4096;

// Initial value of `BufReadIter::limit` set by every constructor in this file;
// `bytes_until_limit` returns this same value while `limit` still equals it.
const NO_LIMIT: u64 = u64::MAX;
29 
/// Dangerous implementation of `BufRead`.
///
/// Unsafe wrapper around `BufRead` which assumes that the `BufRead` buffer is
/// not moved when the `BufRead` itself is moved.
///
/// This assumption is generally incorrect, however, in practice
/// `BufReadIter` is created either from a `BufRead` reference (which
/// cannot be moved, because it is locked by `CodedInputStream`) or from
/// a `BufReader` which does not move its buffer (we know that from
/// inspecting the Rust standard library).
///
/// It is important for `CodedInputStream` performance that small reads
/// (e.g. 4-byte reads) do not involve virtual calls or switches.
/// This is achievable with `BufReadIter`.
#[derive(Debug)]
pub(crate) struct BufReadIter<'a> {
    /// Where the bytes come from (a reader, a byte slice or `Bytes`).
    input_source: InputSource<'a>,
    /// Bytes currently available for reading, plus the position within them.
    buf: InputBuf<'a>,
    /// Absolute position of the first byte of `buf`; `pos()` adds the
    /// position within `buf` to this value.
    pos_of_buf_start: u64,
    /// Absolute position that `pos()` must not exceed; starts as `NO_LIMIT`.
    limit: u64,
}
51 
52 impl<'a> Drop for BufReadIter<'a> {
drop(&mut self)53     fn drop(&mut self) {
54         match self.input_source {
55             InputSource::Read(ref mut buf_read) => buf_read.consume(self.buf.pos_within_buf()),
56             _ => {}
57         }
58     }
59 }
60 
61 impl<'a> BufReadIter<'a> {
from_read(read: &'a mut dyn Read) -> BufReadIter<'a>62     pub(crate) fn from_read(read: &'a mut dyn Read) -> BufReadIter<'a> {
63         BufReadIter {
64             input_source: InputSource::Read(BufReadOrReader::BufReader(BufReader::with_capacity(
65                 INPUT_STREAM_BUFFER_SIZE,
66                 read,
67             ))),
68             buf: InputBuf::empty(),
69             pos_of_buf_start: 0,
70             limit: NO_LIMIT,
71         }
72     }
73 
from_buf_read(buf_read: &'a mut dyn BufRead) -> BufReadIter<'a>74     pub(crate) fn from_buf_read(buf_read: &'a mut dyn BufRead) -> BufReadIter<'a> {
75         BufReadIter {
76             input_source: InputSource::Read(BufReadOrReader::BufRead(buf_read)),
77             buf: InputBuf::empty(),
78             pos_of_buf_start: 0,
79             limit: NO_LIMIT,
80         }
81     }
82 
from_byte_slice(bytes: &'a [u8]) -> BufReadIter<'a>83     pub(crate) fn from_byte_slice(bytes: &'a [u8]) -> BufReadIter<'a> {
84         BufReadIter {
85             input_source: InputSource::Slice(bytes),
86             buf: InputBuf::from_bytes(bytes),
87             pos_of_buf_start: 0,
88             limit: NO_LIMIT,
89         }
90     }
91 
92     #[cfg(feature = "bytes")]
from_bytes(bytes: &'a Bytes) -> BufReadIter<'a>93     pub(crate) fn from_bytes(bytes: &'a Bytes) -> BufReadIter<'a> {
94         BufReadIter {
95             input_source: InputSource::Bytes(bytes),
96             buf: InputBuf::from_bytes(&bytes),
97             pos_of_buf_start: 0,
98             limit: NO_LIMIT,
99         }
100     }
101 
102     #[inline]
assertions(&self)103     fn assertions(&self) {
104         debug_assert!(self.pos() <= self.limit);
105         self.buf.assertions();
106     }
107 
108     #[inline(always)]
pos(&self) -> u64109     pub(crate) fn pos(&self) -> u64 {
110         self.pos_of_buf_start + self.buf.pos_within_buf() as u64
111     }
112 
113     /// Recompute `limit_within_buf` after update of `limit`
114     #[inline]
update_limit_within_buf(&mut self)115     fn update_limit_within_buf(&mut self) {
116         assert!(self.limit >= self.pos_of_buf_start);
117         self.buf.update_limit(self.limit - self.pos_of_buf_start);
118         self.assertions();
119     }
120 
push_limit(&mut self, limit: u64) -> crate::Result<u64>121     pub(crate) fn push_limit(&mut self, limit: u64) -> crate::Result<u64> {
122         let new_limit = match self.pos().checked_add(limit) {
123             Some(new_limit) => new_limit,
124             None => return Err(ProtobufError::WireError(WireError::LimitOverflow).into()),
125         };
126 
127         if new_limit > self.limit {
128             return Err(ProtobufError::WireError(WireError::LimitIncrease).into());
129         }
130 
131         let prev_limit = mem::replace(&mut self.limit, new_limit);
132 
133         self.update_limit_within_buf();
134 
135         Ok(prev_limit)
136     }
137 
138     #[inline]
pop_limit(&mut self, limit: u64)139     pub(crate) fn pop_limit(&mut self, limit: u64) {
140         assert!(limit >= self.limit);
141 
142         self.limit = limit;
143 
144         self.update_limit_within_buf();
145     }
146 
147     #[inline(always)]
remaining_in_buf(&self) -> &[u8]148     pub(crate) fn remaining_in_buf(&self) -> &[u8] {
149         self.buf.remaining_in_buf()
150     }
151 
152     #[inline]
consume(&mut self, amt: usize)153     pub(crate) fn consume(&mut self, amt: usize) {
154         self.buf.consume(amt);
155     }
156 
157     #[inline(always)]
remaining_in_buf_len(&self) -> usize158     pub(crate) fn remaining_in_buf_len(&self) -> usize {
159         self.remaining_in_buf().len()
160     }
161 
162     #[inline(always)]
bytes_until_limit(&self) -> u64163     pub(crate) fn bytes_until_limit(&self) -> u64 {
164         if self.limit == NO_LIMIT {
165             NO_LIMIT
166         } else {
167             self.limit - self.pos()
168         }
169     }
170 
171     #[inline(always)]
eof(&mut self) -> crate::Result<bool>172     pub(crate) fn eof(&mut self) -> crate::Result<bool> {
173         if self.remaining_in_buf_len() != 0 {
174             Ok(false)
175         } else {
176             Ok(self.fill_buf()?.is_empty())
177         }
178     }
179 
read_byte_slow(&mut self) -> crate::Result<u8>180     fn read_byte_slow(&mut self) -> crate::Result<u8> {
181         self.fill_buf_slow()?;
182 
183         if let Some(b) = self.buf.read_byte() {
184             return Ok(b);
185         }
186 
187         Err(WireError::UnexpectedEof.into())
188     }
189 
190     #[inline(always)]
read_byte(&mut self) -> crate::Result<u8>191     pub(crate) fn read_byte(&mut self) -> crate::Result<u8> {
192         if let Some(b) = self.buf.read_byte() {
193             return Ok(b);
194         }
195 
196         self.read_byte_slow()
197     }
198 
199     #[cfg(feature = "bytes")]
read_exact_bytes(&mut self, len: usize) -> crate::Result<Bytes>200     pub(crate) fn read_exact_bytes(&mut self, len: usize) -> crate::Result<Bytes> {
201         if let InputSource::Bytes(bytes) = self.input_source {
202             if len > self.remaining_in_buf_len() {
203                 return Err(ProtobufError::WireError(WireError::UnexpectedEof).into());
204             }
205             let end = self.buf.pos_within_buf() + len;
206 
207             let r = bytes.slice(self.buf.pos_within_buf()..end);
208             self.buf.consume(len);
209             Ok(r)
210         } else {
211             if len >= READ_RAW_BYTES_MAX_ALLOC {
212                 // We cannot trust `len` because protobuf message could be malformed.
213                 // Reading should not result in OOM when allocating a buffer.
214                 let mut v = Vec::new();
215                 self.read_exact_to_vec(len, &mut v)?;
216                 Ok(Bytes::from(v))
217             } else {
218                 let mut r = BytesMut::with_capacity(len);
219                 unsafe {
220                     let buf = Self::uninit_slice_as_mut_slice(&mut r.chunk_mut()[..len]);
221                     self.read_exact(buf)?;
222                     r.advance_mut(len);
223                 }
224                 Ok(r.freeze())
225             }
226         }
227     }
228 
229     #[cfg(feature = "bytes")]
uninit_slice_as_mut_slice(slice: &mut UninitSlice) -> &mut [MaybeUninit<u8>]230     unsafe fn uninit_slice_as_mut_slice(slice: &mut UninitSlice) -> &mut [MaybeUninit<u8>] {
231         use std::slice;
232         slice::from_raw_parts_mut(slice.as_mut_ptr() as *mut MaybeUninit<u8>, slice.len())
233     }
234 
235     /// Returns 0 when EOF or limit reached.
read(&mut self, buf: &mut [u8]) -> crate::Result<usize>236     pub(crate) fn read(&mut self, buf: &mut [u8]) -> crate::Result<usize> {
237         let rem = self.fill_buf()?;
238 
239         let len = cmp::min(rem.len(), buf.len());
240         buf[..len].copy_from_slice(&rem[..len]);
241         self.buf.consume(len);
242         Ok(len)
243     }
244 
consume_buf(&mut self) -> crate::Result<()>245     fn consume_buf(&mut self) -> crate::Result<()> {
246         match &mut self.input_source {
247             InputSource::Read(read) => {
248                 read.consume(self.buf.pos_within_buf());
249                 self.pos_of_buf_start += self.buf.pos_within_buf() as u64;
250                 self.buf = InputBuf::empty();
251                 self.assertions();
252                 Ok(())
253             }
254             _ => Err(WireError::UnexpectedEof.into()),
255         }
256     }
257 
258     /// Read at most `max` bytes.
259     ///
260     /// Returns 0 when EOF or limit reached.
read_to_vec(&mut self, vec: &mut Vec<u8>, max: usize) -> crate::Result<usize>261     fn read_to_vec(&mut self, vec: &mut Vec<u8>, max: usize) -> crate::Result<usize> {
262         let rem = self.fill_buf()?;
263 
264         let len = cmp::min(rem.len(), max);
265         vec.extend_from_slice(&rem[..len]);
266         self.buf.consume(len);
267         Ok(len)
268     }
269 
read_exact_slow(&mut self, buf: &mut [MaybeUninit<u8>]) -> crate::Result<()>270     fn read_exact_slow(&mut self, buf: &mut [MaybeUninit<u8>]) -> crate::Result<()> {
271         if self.bytes_until_limit() < buf.len() as u64 {
272             return Err(ProtobufError::WireError(WireError::UnexpectedEof).into());
273         }
274 
275         self.consume_buf()?;
276 
277         match &mut self.input_source {
278             InputSource::Read(buf_read) => {
279                 buf_read.read_exact_uninit(buf)?;
280                 self.pos_of_buf_start += buf.len() as u64;
281                 self.assertions();
282                 Ok(())
283             }
284             _ => unreachable!(),
285         }
286     }
287 
288     #[inline]
read_exact(&mut self, buf: &mut [MaybeUninit<u8>]) -> crate::Result<()>289     pub(crate) fn read_exact(&mut self, buf: &mut [MaybeUninit<u8>]) -> crate::Result<()> {
290         if self.remaining_in_buf_len() >= buf.len() {
291             self.buf.read_bytes(buf);
292             return Ok(());
293         }
294 
295         self.read_exact_slow(buf)
296     }
297 
298     /// Read raw bytes into the supplied vector.  The vector will be resized as needed and
299     /// overwritten.
read_exact_to_vec( &mut self, count: usize, target: &mut Vec<u8>, ) -> crate::Result<()>300     pub(crate) fn read_exact_to_vec(
301         &mut self,
302         count: usize,
303         target: &mut Vec<u8>,
304     ) -> crate::Result<()> {
305         // TODO: also do some limits when reading from unlimited source
306         if count as u64 > self.bytes_until_limit() {
307             return Err(ProtobufError::WireError(WireError::TruncatedMessage).into());
308         }
309 
310         target.clear();
311 
312         if count >= READ_RAW_BYTES_MAX_ALLOC && count > target.capacity() {
313             // avoid calling `reserve` on buf with very large buffer: could be a malformed message
314 
315             target.reserve(READ_RAW_BYTES_MAX_ALLOC);
316 
317             while target.len() < count {
318                 if count - target.len() <= target.len() {
319                     target.reserve_exact(count - target.len());
320                 } else {
321                     target.reserve(1);
322                 }
323 
324                 let max = cmp::min(target.capacity() - target.len(), count - target.len());
325                 let read = self.read_to_vec(target, max)?;
326                 if read == 0 {
327                     return Err(ProtobufError::WireError(WireError::TruncatedMessage).into());
328                 }
329             }
330         } else {
331             target.reserve_exact(count);
332 
333             unsafe {
334                 self.read_exact(&mut target.spare_capacity_mut()[..count])?;
335                 target.set_len(count);
336             }
337         }
338 
339         debug_assert_eq!(count, target.len());
340 
341         Ok(())
342     }
343 
skip_bytes(&mut self, count: u32) -> crate::Result<()>344     pub(crate) fn skip_bytes(&mut self, count: u32) -> crate::Result<()> {
345         if count as usize <= self.remaining_in_buf_len() {
346             self.buf.consume(count as usize);
347             return Ok(());
348         }
349 
350         if count as u64 > self.bytes_until_limit() {
351             return Err(WireError::TruncatedMessage.into());
352         }
353 
354         self.consume_buf()?;
355 
356         match &mut self.input_source {
357             InputSource::Read(read) => {
358                 read.skip_bytes(count as usize)?;
359                 self.pos_of_buf_start += count as u64;
360                 self.assertions();
361                 Ok(())
362             }
363             _ => unreachable!(),
364         }
365     }
366 
fill_buf_slow(&mut self) -> crate::Result<()>367     fn fill_buf_slow(&mut self) -> crate::Result<()> {
368         self.assertions();
369         if self.limit == self.pos() {
370             return Ok(());
371         }
372 
373         match self.input_source {
374             InputSource::Read(..) => {}
375             _ => return Ok(()),
376         }
377 
378         self.consume_buf()?;
379 
380         match self.input_source {
381             InputSource::Read(ref mut buf_read) => {
382                 self.buf = unsafe { InputBuf::from_bytes_ignore_lifetime(buf_read.fill_buf()?) };
383                 self.update_limit_within_buf();
384                 Ok(())
385             }
386             _ => {
387                 unreachable!();
388             }
389         }
390     }
391 
392     #[inline(always)]
fill_buf(&mut self) -> crate::Result<&[u8]>393     pub(crate) fn fill_buf(&mut self) -> crate::Result<&[u8]> {
394         let rem = self.buf.remaining_in_buf();
395         if !rem.is_empty() {
396             return Ok(rem);
397         }
398 
399         if self.limit == self.pos() {
400             return Ok(&[]);
401         }
402 
403         self.fill_buf_slow()?;
404 
405         Ok(self.buf.remaining_in_buf())
406     }
407 }
408 
#[cfg(all(test, feature = "bytes"))]
mod test_bytes {
    use std::io::Write;

    use super::*;

    /// Produces exactly `len` bytes of ASCII digits by repeatedly appending
    /// the decimal form of the current length and truncating at the end.
    fn make_long_string(len: usize) -> Vec<u8> {
        let mut out = Vec::new();
        loop {
            let so_far = out.len();
            if so_far >= len {
                break;
            }
            write!(&mut out, "{}", so_far).expect("unexpected");
        }
        out.truncate(len);
        out
    }

    /// Reading from a slice source returns the right bytes and leaves the
    /// position just after them.
    #[test]
    #[cfg_attr(miri, ignore)] // bytes violates SB, see https://github.com/tokio-rs/bytes/issues/522
    fn read_exact_bytes_from_slice() {
        let data = make_long_string(100);
        let mut bri = BufReadIter::from_byte_slice(&data[..]);
        let head = bri.read_exact_bytes(90).unwrap();
        assert_eq!(&data[..90], &head[..]);
        let next = bri.read_byte().expect("read_byte");
        assert_eq!(data[90], next);
    }

    /// Reading from a `Bytes` source returns a view into the same memory
    /// (same start pointer) rather than a copy.
    #[test]
    #[cfg_attr(miri, ignore)] // bytes violates SB, see https://github.com/tokio-rs/bytes/issues/522
    fn read_exact_bytes_from_bytes() {
        let data = Bytes::from(make_long_string(100));
        let mut bri = BufReadIter::from_bytes(&data);
        let head = bri.read_exact_bytes(90).unwrap();
        assert_eq!(&data[..90], &head[..]);
        assert_eq!(data[..90].as_ptr(), head.as_ptr());
        let next = bri.read_byte().expect("read_byte");
        assert_eq!(data[90], next);
    }
}
445 
#[cfg(test)]
mod test {
    use std::io;
    use std::io::BufRead;
    use std::io::Read;

    use super::*;

    /// Once exactly the bytes up to a pushed limit have been read, `eof` must
    /// report `true` without asking the underlying reader for more data.
    #[test]
    fn eof_at_limit() {
        /// `BufRead` that hands out five bytes and panics on any access that
        /// would indicate reading past them.
        struct Read5ThenPanic {
            pos: usize,
        }

        impl Read for Read5ThenPanic {
            fn read(&mut self, _buf: &mut [u8]) -> io::Result<usize> {
                unreachable!();
            }
        }

        impl BufRead for Read5ThenPanic {
            fn fill_buf(&mut self) -> io::Result<&[u8]> {
                // A second fill (after the five bytes were consumed) is a bug.
                assert_eq!(0, self.pos);
                static ZERO_TO_FIVE: &[u8] = &[0, 1, 2, 3, 4];
                Ok(ZERO_TO_FIVE)
            }

            fn consume(&mut self, amt: usize) {
                if amt == 0 {
                    // `BufReadIter` consumes its (still empty) buffer before
                    // the first fill; nothing to account for.
                    return;
                }

                // The only other consume expected is for all five bytes,
                // which happens when `BufReadIter` is dropped.
                assert_eq!(0, self.pos);
                assert_eq!(5, amt);
                self.pos += amt;
            }
        }

        let mut read = Read5ThenPanic { pos: 0 };
        let mut buf_read_iter = BufReadIter::from_buf_read(&mut read);
        assert_eq!(0, buf_read_iter.pos());
        // Fail right here if the limit cannot be installed, instead of
        // letting a later assertion trip for an unrelated-looking reason.
        let prev_limit = buf_read_iter.push_limit(5).expect("push_limit");
        assert_eq!(NO_LIMIT, prev_limit);
        buf_read_iter.read_byte().expect("read_byte");
        buf_read_iter
            .read_exact(&mut [
                MaybeUninit::uninit(),
                MaybeUninit::uninit(),
                MaybeUninit::uninit(),
                MaybeUninit::uninit(),
            ])
            .expect("read_exact");
        assert!(buf_read_iter.eof().expect("eof"));
    }
}
501