• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 use crate::follow::Follow;
2 use crate::{ForwardsUOffset, SOffsetT, SkipSizePrefix, UOffsetT, VOffsetT, Vector, SIZE_UOFFSET};
3 use std::ops::Range;
4 use thiserror::Error;
5 
/// One step on the path from the root of a flatbuffer to the location of a data error.
///
/// Traces are most useful for `MissingRequiredField` and `Utf8Error`; the other data errors
/// should not be producible by correct flatbuffers implementations. DoS-detecting errors never
/// carry a trace.
#[derive(Clone, Debug, PartialEq, Eq)]
pub enum ErrorTraceDetail {
    /// The error occurred while verifying element `index` of a vector, found at `position`.
    VectorElement { index: usize, position: usize },
    /// The error occurred while verifying the table field `field_name`, found at `position`.
    TableField {
        field_name: &'static str,
        position: usize,
    },
    /// The error occurred while verifying the union variant `variant`, found at `position`.
    UnionVariant {
        variant: &'static str,
        position: usize,
    },
}

/// The steps leading to a data error. Steps are appended while the error propagates outward,
/// so the innermost step comes first.
#[derive(PartialEq, Eq, Default, Debug, Clone)]
pub struct ErrorTrace(Vec<ErrorTraceDetail>);

impl AsRef<[ErrorTraceDetail]> for ErrorTrace {
    /// Borrows the recorded steps as a slice.
    #[inline]
    fn as_ref(&self) -> &[ErrorTraceDetail] {
        self.0.as_slice()
    }
}
32 
33 /// Describes how a flatuffer is invalid and, for data errors, roughly where. No extra tracing
34 /// information is given for DoS detecting errors since it will probably be a lot.
35 #[derive(Clone, Error, Debug, PartialEq, Eq)]
36 pub enum InvalidFlatbuffer {
37     #[error("Missing required field `{required}`.\n{error_trace}")]
38     MissingRequiredField {
39         required: &'static str,
40         error_trace: ErrorTrace,
41     },
42     #[error(
43         "Union exactly one of union discriminant (`{field_type}`) and value \
44              (`{field}`) are present.\n{error_trace}"
45     )]
46     InconsistentUnion {
47         field: &'static str,
48         field_type: &'static str,
49         error_trace: ErrorTrace,
50     },
51     #[error("Utf8 error for string in {range:?}: {error}\n{error_trace}")]
52     Utf8Error {
53         #[source]
54         error: std::str::Utf8Error,
55         range: Range<usize>,
56         error_trace: ErrorTrace,
57     },
58     #[error("String in range [{}, {}) is missing its null terminator.\n{error_trace}",
59             range.start, range.end)]
60     MissingNullTerminator {
61         range: Range<usize>,
62         error_trace: ErrorTrace,
63     },
64     #[error("Type `{unaligned_type}` at position {position} is unaligned.\n{error_trace}")]
65     Unaligned {
66         position: usize,
67         unaligned_type: &'static str,
68         error_trace: ErrorTrace,
69     },
70     #[error("Range [{}, {}) is out of bounds.\n{error_trace}", range.start, range.end)]
71     RangeOutOfBounds {
72         range: Range<usize>,
73         error_trace: ErrorTrace,
74     },
75     #[error(
76         "Signed offset at position {position} has value {soffset} which points out of bounds.\
77              \n{error_trace}"
78     )]
79     SignedOffsetOutOfBounds {
80         soffset: SOffsetT,
81         position: usize,
82         error_trace: ErrorTrace,
83     },
84     // Dos detecting errors. These do not get error traces since it will probably be very large.
85     #[error("Too many tables.")]
86     TooManyTables,
87     #[error("Apparent size too large.")]
88     ApparentSizeTooLarge,
89     #[error("Nested table depth limit reached.")]
90     DepthLimitReached,
91 }
92 
93 impl std::fmt::Display for ErrorTrace {
fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result94     fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
95         use ErrorTraceDetail::*;
96         for e in self.0.iter() {
97             match e {
98                 VectorElement { index, position } => {
99                     writeln!(
100                         f,
101                         "\twhile verifying vector element {:?} at position {:?}",
102                         index, position
103                     )?;
104                 }
105                 TableField {
106                     field_name,
107                     position,
108                 } => {
109                     writeln!(
110                         f,
111                         "\twhile verifying table field `{}` at position {:?}",
112                         field_name, position
113                     )?;
114                 }
115                 UnionVariant { variant, position } => {
116                     writeln!(
117                         f,
118                         "\t while verifying union variant `{}` at position {:?}",
119                         variant, position
120                     )?;
121                 }
122             }
123         }
124         Ok(())
125     }
126 }
127 
128 pub type Result<T> = std::prelude::v1::Result<T, InvalidFlatbuffer>;
129 
130 impl InvalidFlatbuffer {
new_range_oob<T>(start: usize, end: usize) -> Result<T>131     fn new_range_oob<T>(start: usize, end: usize) -> Result<T> {
132         Err(Self::RangeOutOfBounds {
133             range: Range { start, end },
134             error_trace: Default::default(),
135         })
136     }
new_inconsistent_union<T>(field: &'static str, field_type: &'static str) -> Result<T>137     fn new_inconsistent_union<T>(field: &'static str, field_type: &'static str) -> Result<T> {
138         Err(Self::InconsistentUnion {
139             field,
140             field_type,
141             error_trace: Default::default(),
142         })
143     }
new_missing_required<T>(required: &'static str) -> Result<T>144     fn new_missing_required<T>(required: &'static str) -> Result<T> {
145         Err(Self::MissingRequiredField {
146             required,
147             error_trace: Default::default(),
148         })
149     }
150 }
151 
152 /// Records the path to the verifier detail if the error is a data error and not a DoS error.
append_trace<T>(mut res: Result<T>, d: ErrorTraceDetail) -> Result<T>153 fn append_trace<T>(mut res: Result<T>, d: ErrorTraceDetail) -> Result<T> {
154     if let Err(e) = res.as_mut() {
155         use InvalidFlatbuffer::*;
156         if let MissingRequiredField { error_trace, .. }
157         | Unaligned { error_trace, .. }
158         | RangeOutOfBounds { error_trace, .. }
159         | InconsistentUnion { error_trace, .. }
160         | Utf8Error { error_trace, .. }
161         | MissingNullTerminator { error_trace, .. }
162         | SignedOffsetOutOfBounds { error_trace, .. } = e
163         {
164             error_trace.0.push(d)
165         }
166     }
167     res
168 }
169 
170 /// Adds a TableField trace detail if `res` is a data error.
trace_field<T>(res: Result<T>, field_name: &'static str, position: usize) -> Result<T>171 fn trace_field<T>(res: Result<T>, field_name: &'static str, position: usize) -> Result<T> {
172     append_trace(
173         res,
174         ErrorTraceDetail::TableField {
175             field_name,
176             position,
177         },
178     )
179 }
180 /// Adds a TableField trace detail if `res` is a data error.
trace_elem<T>(res: Result<T>, index: usize, position: usize) -> Result<T>181 fn trace_elem<T>(res: Result<T>, index: usize, position: usize) -> Result<T> {
182     append_trace(res, ErrorTraceDetail::VectorElement { index, position })
183 }
184 
/// Limits and toggles that control how strictly a `Verifier` checks a buffer.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct VerifierOptions {
    /// Maximum depth of nested tables allowed in a valid flatbuffer.
    pub max_depth: usize,
    /// Maximum number of tables allowed in a valid flatbuffer.
    pub max_tables: usize,
    /// Maximum "apparent" size of the message if the Flatbuffer object DAG is expanded into a
    /// tree.
    pub max_apparent_size: usize,
    /// Ignore errors where a string is missing its null terminator.
    /// This is mostly a problem if the message will be sent to a client using old c-strings.
    pub ignore_missing_null_terminator: bool,
    // TODO: possible further options: ignoring utf8 errors (strings may come from C++),
    // erroring on unrecognized enums and unions (a possible footgun), skipping nested
    // flatbuffers.
}

impl Default for VerifierOptions {
    /// Depth limit 64, at most one million tables, an apparent size of at most 2^31 bytes, and
    /// null terminators required on strings.
    fn default() -> Self {
        let max_depth = 64;
        let max_tables = 1_000_000;
        // NOTE(review): the upstream comment "size_ might do something different" suggests
        // this limit may not match other implementations; confirm.
        let max_apparent_size = 1usize << 31;
        Self {
            max_depth,
            max_tables,
            max_apparent_size,
            ignore_missing_null_terminator: false,
        }
    }
}
212 
213 /// Carries the verification state. Should not be reused between tables.
214 #[derive(Debug)]
215 pub struct Verifier<'opts, 'buf> {
216     buffer: &'buf [u8],
217     opts: &'opts VerifierOptions,
218     depth: usize,
219     num_tables: usize,
220     apparent_size: usize,
221 }
222 impl<'opts, 'buf> Verifier<'opts, 'buf> {
new(opts: &'opts VerifierOptions, buffer: &'buf [u8]) -> Self223     pub fn new(opts: &'opts VerifierOptions, buffer: &'buf [u8]) -> Self {
224         Self {
225             opts,
226             buffer,
227             depth: 0,
228             num_tables: 0,
229             apparent_size: 0,
230         }
231     }
232     /// Resets verifier internal state.
233     #[inline]
reset(&mut self)234     pub fn reset(&mut self) {
235         self.depth = 0;
236         self.num_tables = 0;
237         self.num_tables = 0;
238     }
239     /// Checks `pos` is aligned to T's alignment. This does not mean `buffer[pos]` is aligned w.r.t
240     /// memory since `buffer: &[u8]` has alignment 1.
241     ///
242     /// ### WARNING
243     /// This does not work for flatbuffers-structs as they have alignment 1 according to
244     /// `core::mem::align_of` but are meant to have higher alignment within a Flatbuffer w.r.t.
245     /// `buffer[0]`. TODO(caspern).
246     #[inline]
is_aligned<T>(&self, pos: usize) -> Result<()>247     fn is_aligned<T>(&self, pos: usize) -> Result<()> {
248         if pos % std::mem::align_of::<T>() == 0 {
249             Ok(())
250         } else {
251             Err(InvalidFlatbuffer::Unaligned {
252                 unaligned_type: std::any::type_name::<T>(),
253                 position: pos,
254                 error_trace: Default::default(),
255             })
256         }
257     }
258     #[inline]
range_in_buffer(&mut self, pos: usize, size: usize) -> Result<()>259     fn range_in_buffer(&mut self, pos: usize, size: usize) -> Result<()> {
260         let end = pos.saturating_add(size);
261         if end > self.buffer.len() {
262             return InvalidFlatbuffer::new_range_oob(pos, end);
263         }
264         self.apparent_size += size;
265         if self.apparent_size > self.opts.max_apparent_size {
266             return Err(InvalidFlatbuffer::ApparentSizeTooLarge);
267         }
268         Ok(())
269     }
270     /// Check that there really is a T in there.
271     #[inline]
in_buffer<T>(&mut self, pos: usize) -> Result<()>272     pub fn in_buffer<T>(&mut self, pos: usize) -> Result<()> {
273         self.is_aligned::<T>(pos)?;
274         self.range_in_buffer(pos, std::mem::size_of::<T>())
275     }
276     #[inline]
get_u16(&mut self, pos: usize) -> Result<u16>277     fn get_u16(&mut self, pos: usize) -> Result<u16> {
278         self.in_buffer::<u16>(pos)?;
279         Ok(u16::from_le_bytes([self.buffer[pos], self.buffer[pos + 1]]))
280     }
281     #[inline]
get_uoffset(&mut self, pos: usize) -> Result<UOffsetT>282     fn get_uoffset(&mut self, pos: usize) -> Result<UOffsetT> {
283         self.in_buffer::<u32>(pos)?;
284         Ok(u32::from_le_bytes([
285             self.buffer[pos],
286             self.buffer[pos + 1],
287             self.buffer[pos + 2],
288             self.buffer[pos + 3],
289         ]))
290     }
291     #[inline]
deref_soffset(&mut self, pos: usize) -> Result<usize>292     fn deref_soffset(&mut self, pos: usize) -> Result<usize> {
293         self.in_buffer::<SOffsetT>(pos)?;
294         let offset = SOffsetT::from_le_bytes([
295             self.buffer[pos],
296             self.buffer[pos + 1],
297             self.buffer[pos + 2],
298             self.buffer[pos + 3],
299         ]);
300 
301         // signed offsets are subtracted.
302         let derefed = if offset > 0 {
303             pos.checked_sub(offset.abs() as usize)
304         } else {
305             pos.checked_add(offset.abs() as usize)
306         };
307         if let Some(x) = derefed {
308             if x < self.buffer.len() {
309                 return Ok(x);
310             }
311         }
312         Err(InvalidFlatbuffer::SignedOffsetOutOfBounds {
313             soffset: offset,
314             position: pos,
315             error_trace: Default::default(),
316         })
317     }
318     #[inline]
visit_table<'ver>( &'ver mut self, table_pos: usize, ) -> Result<TableVerifier<'ver, 'opts, 'buf>>319     pub fn visit_table<'ver>(
320         &'ver mut self,
321         table_pos: usize,
322     ) -> Result<TableVerifier<'ver, 'opts, 'buf>> {
323         let vtable_pos = self.deref_soffset(table_pos)?;
324         let vtable_len = self.get_u16(vtable_pos)? as usize;
325         self.is_aligned::<VOffsetT>(vtable_pos.saturating_add(vtable_len))?; // i.e. vtable_len is even.
326         self.range_in_buffer(vtable_pos, vtable_len)?;
327         // Check bounds.
328         self.num_tables += 1;
329         if self.num_tables > self.opts.max_tables {
330             return Err(InvalidFlatbuffer::TooManyTables);
331         }
332         self.depth += 1;
333         if self.depth > self.opts.max_depth {
334             return Err(InvalidFlatbuffer::DepthLimitReached);
335         }
336         Ok(TableVerifier {
337             pos: table_pos,
338             vtable: vtable_pos,
339             vtable_len,
340             verifier: self,
341         })
342     }
343 
344     /// Runs the union variant's type's verifier assuming the variant is at the given position,
345     /// tracing the error.
verify_union_variant<T: Verifiable>( &mut self, variant: &'static str, position: usize, ) -> Result<()>346     pub fn verify_union_variant<T: Verifiable>(
347         &mut self,
348         variant: &'static str,
349         position: usize,
350     ) -> Result<()> {
351         let res = T::run_verifier(self, position);
352         append_trace(res, ErrorTraceDetail::UnionVariant { variant, position })
353     }
354 }
355 
356 // Cache table metadata in usize so we don't have to cast types or jump around so much.
357 // We will visit every field anyway.
358 pub struct TableVerifier<'ver, 'opts, 'buf> {
359     // Absolute position of table in buffer
360     pos: usize,
361     // Absolute position of vtable in buffer.
362     vtable: usize,
363     // Length of vtable.
364     vtable_len: usize,
365     // Verifier struct which holds the surrounding state and options.
366     verifier: &'ver mut Verifier<'opts, 'buf>,
367 }
368 impl<'ver, 'opts, 'buf> TableVerifier<'ver, 'opts, 'buf> {
deref(&mut self, field: VOffsetT) -> Result<Option<usize>>369     fn deref(&mut self, field: VOffsetT) -> Result<Option<usize>> {
370         let field = field as usize;
371         if field < self.vtable_len {
372             let field_offset = self.verifier.get_u16(self.vtable.saturating_add(field))?;
373             if field_offset > 0 {
374                 // Field is present.
375                 let field_pos = self.pos.saturating_add(field_offset as usize);
376                 return Ok(Some(field_pos));
377             }
378         }
379         Ok(None)
380     }
381 
382     #[inline]
visit_field<T: Verifiable>( mut self, field_name: &'static str, field: VOffsetT, required: bool, ) -> Result<Self>383     pub fn visit_field<T: Verifiable>(
384         mut self,
385         field_name: &'static str,
386         field: VOffsetT,
387         required: bool,
388     ) -> Result<Self> {
389         if let Some(field_pos) = self.deref(field)? {
390             trace_field(
391                 T::run_verifier(self.verifier, field_pos),
392                 field_name,
393                 field_pos,
394             )?;
395             return Ok(self);
396         }
397         if required {
398             InvalidFlatbuffer::new_missing_required(field_name)
399         } else {
400             Ok(self)
401         }
402     }
403     #[inline]
404     /// Union verification is complicated. The schemas passes this function the metadata of the
405     /// union's key (discriminant) and value fields, and a callback. The function verifies and
406     /// reads the key, then invokes the callback to perform data-dependent verification.
visit_union<Key, UnionVerifier>( mut self, key_field_name: &'static str, key_field_voff: VOffsetT, val_field_name: &'static str, val_field_voff: VOffsetT, required: bool, verify_union: UnionVerifier, ) -> Result<Self> where Key: Follow<'buf> + Verifiable, UnionVerifier: (std::ops::FnOnce(<Key as Follow<'buf>>::Inner, &mut Verifier, usize) -> Result<()>),407     pub fn visit_union<Key, UnionVerifier>(
408         mut self,
409         key_field_name: &'static str,
410         key_field_voff: VOffsetT,
411         val_field_name: &'static str,
412         val_field_voff: VOffsetT,
413         required: bool,
414         verify_union: UnionVerifier,
415     ) -> Result<Self>
416     where
417         Key: Follow<'buf> + Verifiable,
418         UnionVerifier:
419             (std::ops::FnOnce(<Key as Follow<'buf>>::Inner, &mut Verifier, usize) -> Result<()>),
420         // NOTE: <Key as Follow<'buf>>::Inner == Key
421     {
422         // TODO(caspern): how to trace vtable errors?
423         let val_pos = self.deref(val_field_voff)?;
424         let key_pos = self.deref(key_field_voff)?;
425         match (key_pos, val_pos) {
426             (None, None) => {
427                 if required {
428                     InvalidFlatbuffer::new_missing_required(val_field_name)
429                 } else {
430                     Ok(self)
431                 }
432             }
433             (Some(k), Some(v)) => {
434                 trace_field(Key::run_verifier(self.verifier, k), key_field_name, k)?;
435                 let discriminant = Key::follow(self.verifier.buffer, k);
436                 trace_field(
437                     verify_union(discriminant, self.verifier, v),
438                     val_field_name,
439                     v,
440                 )?;
441                 Ok(self)
442             }
443             _ => InvalidFlatbuffer::new_inconsistent_union(key_field_name, val_field_name),
444         }
445     }
finish(self) -> &'ver mut Verifier<'opts, 'buf>446     pub fn finish(self) -> &'ver mut Verifier<'opts, 'buf> {
447         self.verifier.depth -= 1;
448         self.verifier
449     }
450 }
451 
452 // Needs to be implemented for Tables and maybe structs.
453 // Unions need some special treatment.
454 pub trait Verifiable {
455     /// Runs the verifier for this type, assuming its at position `pos` in the verifier's buffer.
456     /// Should not need to be called directly.
run_verifier(v: &mut Verifier, pos: usize) -> Result<()>457     fn run_verifier(v: &mut Verifier, pos: usize) -> Result<()>;
458 }
459 
460 // Verify the uoffset and then pass verifier to the type being pointed to.
461 impl<T: Verifiable> Verifiable for ForwardsUOffset<T> {
462     #[inline]
run_verifier(v: &mut Verifier, pos: usize) -> Result<()>463     fn run_verifier(v: &mut Verifier, pos: usize) -> Result<()> {
464         let offset = v.get_uoffset(pos)? as usize;
465         let next_pos = offset.saturating_add(pos);
466         T::run_verifier(v, next_pos)
467     }
468 }
469 
470 /// Checks and returns the range containing the flatbuffers vector.
verify_vector_range<T>(v: &mut Verifier, pos: usize) -> Result<std::ops::Range<usize>>471 fn verify_vector_range<T>(v: &mut Verifier, pos: usize) -> Result<std::ops::Range<usize>> {
472     let len = v.get_uoffset(pos)? as usize;
473     let start = pos.saturating_add(SIZE_UOFFSET);
474     v.is_aligned::<T>(start)?;
475     let size = len.saturating_mul(std::mem::size_of::<T>());
476     let end = start.saturating_add(size);
477     v.range_in_buffer(start, size)?;
478     Ok(std::ops::Range { start, end })
479 }
480 
/// Marker for scalar element types whose vectors are verified by a length, alignment and
/// bounds check alone, with no per-element verification.
pub trait SimpleToVerifyInSlice {}

impl SimpleToVerifyInSlice for bool {}

// Integers.
impl SimpleToVerifyInSlice for i8 {}
impl SimpleToVerifyInSlice for u8 {}
impl SimpleToVerifyInSlice for i16 {}
impl SimpleToVerifyInSlice for u16 {}
impl SimpleToVerifyInSlice for i32 {}
impl SimpleToVerifyInSlice for u32 {}
impl SimpleToVerifyInSlice for i64 {}
impl SimpleToVerifyInSlice for u64 {}

// Floating point.
impl SimpleToVerifyInSlice for f32 {}
impl SimpleToVerifyInSlice for f64 {}
493 
494 impl<T: SimpleToVerifyInSlice> Verifiable for Vector<'_, T> {
run_verifier(v: &mut Verifier, pos: usize) -> Result<()>495     fn run_verifier(v: &mut Verifier, pos: usize) -> Result<()> {
496         verify_vector_range::<T>(v, pos)?;
497         Ok(())
498     }
499 }
500 
501 impl<T: Verifiable> Verifiable for SkipSizePrefix<T> {
502     #[inline]
run_verifier(v: &mut Verifier, pos: usize) -> Result<()>503     fn run_verifier(v: &mut Verifier, pos: usize) -> Result<()> {
504         T::run_verifier(v, pos.saturating_add(crate::SIZE_SIZEPREFIX))
505     }
506 }
507 
508 impl<T: Verifiable> Verifiable for Vector<'_, ForwardsUOffset<T>> {
509     #[inline]
run_verifier(v: &mut Verifier, pos: usize) -> Result<()>510     fn run_verifier(v: &mut Verifier, pos: usize) -> Result<()> {
511         let range = verify_vector_range::<ForwardsUOffset<T>>(v, pos)?;
512         let size = std::mem::size_of::<ForwardsUOffset<T>>();
513         for (i, element_pos) in range.step_by(size).enumerate() {
514             trace_elem(
515                 <ForwardsUOffset<T>>::run_verifier(v, element_pos),
516                 i,
517                 element_pos,
518             )?;
519         }
520         Ok(())
521     }
522 }
523 
524 impl<'a> Verifiable for &'a str {
525     #[inline]
run_verifier(v: &mut Verifier, pos: usize) -> Result<()>526     fn run_verifier(v: &mut Verifier, pos: usize) -> Result<()> {
527         let range = verify_vector_range::<u8>(v, pos)?;
528         let has_null_terminator = v.buffer.get(range.end).map(|&b| b == 0).unwrap_or(false);
529         let s = std::str::from_utf8(&v.buffer[range.clone()]);
530         if let Err(error) = s {
531             return Err(InvalidFlatbuffer::Utf8Error {
532                 error,
533                 range,
534                 error_trace: Default::default(),
535             });
536         }
537         if !v.opts.ignore_missing_null_terminator && !has_null_terminator {
538             return Err(InvalidFlatbuffer::MissingNullTerminator {
539                 range,
540                 error_trace: Default::default(),
541             });
542         }
543         Ok(())
544     }
545 }
546 
547 // Verify VectorOfTables, Unions, Arrays, Structs...
548 macro_rules! impl_verifiable_for {
549     ($T: ty) => {
550         impl Verifiable for $T {
551             #[inline]
552             fn run_verifier<'opts, 'buf>(v: &mut Verifier<'opts, 'buf>, pos: usize) -> Result<()> {
553                 v.in_buffer::<$T>(pos)
554             }
555         }
556     };
557 }
558 impl_verifiable_for!(bool);
559 impl_verifiable_for!(u8);
560 impl_verifiable_for!(i8);
561 impl_verifiable_for!(u16);
562 impl_verifiable_for!(i16);
563 impl_verifiable_for!(u32);
564 impl_verifiable_for!(i32);
565 impl_verifiable_for!(f32);
566 impl_verifiable_for!(u64);
567 impl_verifiable_for!(i64);
568 impl_verifiable_for!(f64);
569