• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /*
2  * Copyright 2018 Google Inc. All rights reserved.
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *     http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 mod reflection_generated;
18 mod reflection_verifier;
19 mod safe_buffer;
20 mod r#struct;
21 pub use crate::r#struct::Struct;
22 pub use crate::reflection_generated::reflection;
23 pub use crate::safe_buffer::SafeBuffer;
24 
25 use flatbuffers::{
26     emplace_scalar, read_scalar, EndianScalar, Follow, ForwardsUOffset, InvalidFlatbuffer,
27     SOffsetT, Table, UOffsetT, VOffsetT, Vector, SIZE_SOFFSET, SIZE_UOFFSET,
28 };
29 use reflection_generated::reflection::{BaseType, Field, Object, Schema};
30 
31 use core::mem::size_of;
32 use num_traits::float::Float;
33 use num_traits::int::PrimInt;
34 use num_traits::FromPrimitive;
35 use thiserror::Error;
36 
37 #[derive(Error, Debug, PartialEq)]
38 pub enum FlatbufferError {
39     #[error(transparent)]
40     VerificationError(#[from] flatbuffers::InvalidFlatbuffer),
41     #[error("Failed to convert between data type {0} and field type {1}")]
42     FieldTypeMismatch(String, String),
43     #[error("Set field value not supported for non-populated or non-scalar fields")]
44     SetValueNotSupported,
45     #[error(transparent)]
46     ParseFloatError(#[from] std::num::ParseFloatError),
47     #[error(transparent)]
48     TryFromIntError(#[from] std::num::TryFromIntError),
49     #[error("Couldn't set string because cache vector is polluted")]
50     SetStringPolluted,
51     #[error("Invalid schema: Polluted buffer or the schema doesn't match the buffer.")]
52     InvalidSchema,
53     #[error("Type not supported: {0}")]
54     TypeNotSupported(String),
55     #[error("No type or invalid type found in union enum")]
56     InvalidUnionEnum,
57     #[error("Table or Struct doesn't belong to the buffer")]
58     InvalidTableOrStruct,
59     #[error("Field not found in the table schema")]
60     FieldNotFound,
61 }
62 
63 pub type FlatbufferResult<T, E = FlatbufferError> = core::result::Result<T, E>;
64 
65 /// Gets the root table from a trusted Flatbuffer.
66 ///
67 /// # Safety
68 ///
69 /// Flatbuffers accessors do not perform validation checks before accessing. Users
70 /// must trust [data] contains a valid flatbuffer. Reading unchecked buffers may cause panics or even UB.
get_any_root(data: &[u8]) -> Table71 pub unsafe fn get_any_root(data: &[u8]) -> Table {
72     <ForwardsUOffset<Table>>::follow(data, 0)
73 }
74 
75 /// Gets an integer table field given its exact type. Returns default integer value if the field is not set. Returns [None] if no default value is found. Returns error if the type size doesn't match.
76 ///
77 /// # Safety
78 ///
79 /// The value of the corresponding slot must have type T
get_field_integer<T: for<'a> Follow<'a, Inner = T> + PrimInt + FromPrimitive>( table: &Table, field: &Field, ) -> FlatbufferResult<Option<T>>80 pub unsafe fn get_field_integer<T: for<'a> Follow<'a, Inner = T> + PrimInt + FromPrimitive>(
81     table: &Table,
82     field: &Field,
83 ) -> FlatbufferResult<Option<T>> {
84     if size_of::<T>() != get_type_size(field.type_().base_type()) {
85         return Err(FlatbufferError::FieldTypeMismatch(
86             std::any::type_name::<T>().to_string(),
87             field
88                 .type_()
89                 .base_type()
90                 .variant_name()
91                 .unwrap_or_default()
92                 .to_string(),
93         ));
94     }
95 
96     let default = T::from_i64(field.default_integer());
97     Ok(table.get::<T>(field.offset(), default))
98 }
99 
100 /// Gets a floating point table field given its exact type. Returns default float value if the field is not set. Returns [None] if no default value is found. Returns error if the type doesn't match.
101 ///
102 /// # Safety
103 ///
104 /// The value of the corresponding slot must have type T
get_field_float<T: for<'a> Follow<'a, Inner = T> + Float>( table: &Table, field: &Field, ) -> FlatbufferResult<Option<T>>105 pub unsafe fn get_field_float<T: for<'a> Follow<'a, Inner = T> + Float>(
106     table: &Table,
107     field: &Field,
108 ) -> FlatbufferResult<Option<T>> {
109     if size_of::<T>() != get_type_size(field.type_().base_type()) {
110         return Err(FlatbufferError::FieldTypeMismatch(
111             std::any::type_name::<T>().to_string(),
112             field
113                 .type_()
114                 .base_type()
115                 .variant_name()
116                 .unwrap_or_default()
117                 .to_string(),
118         ));
119     }
120 
121     let default = T::from(field.default_real());
122     Ok(table.get::<T>(field.offset(), default))
123 }
124 
125 /// Gets a String table field given its exact type. Returns empty string if the field is not set. Returns [None] if no default value is found. Returns error if the type size doesn't match.
126 ///
127 /// # Safety
128 ///
129 /// The value of the corresponding slot must have type String
get_field_string<'a>( table: &Table<'a>, field: &Field, ) -> FlatbufferResult<Option<&'a str>>130 pub unsafe fn get_field_string<'a>(
131     table: &Table<'a>,
132     field: &Field,
133 ) -> FlatbufferResult<Option<&'a str>> {
134     if field.type_().base_type() != BaseType::String {
135         return Err(FlatbufferError::FieldTypeMismatch(
136             String::from("String"),
137             field
138                 .type_()
139                 .base_type()
140                 .variant_name()
141                 .unwrap_or_default()
142                 .to_string(),
143         ));
144     }
145 
146     Ok(table.get::<ForwardsUOffset<&'a str>>(field.offset(), Some("")))
147 }
148 
149 /// Gets a [Struct] table field given its exact type. Returns [None] if the field is not set. Returns error if the type doesn't match.
150 ///
151 /// # Safety
152 ///
153 /// The value of the corresponding slot must have type Struct
get_field_struct<'a>( table: &Table<'a>, field: &Field, ) -> FlatbufferResult<Option<Struct<'a>>>154 pub unsafe fn get_field_struct<'a>(
155     table: &Table<'a>,
156     field: &Field,
157 ) -> FlatbufferResult<Option<Struct<'a>>> {
158     // TODO inherited from C++: This does NOT check if the field is a table or struct, but we'd need
159     // access to the schema to check the is_struct flag.
160     if field.type_().base_type() != BaseType::Obj {
161         return Err(FlatbufferError::FieldTypeMismatch(
162             String::from("Obj"),
163             field
164                 .type_()
165                 .base_type()
166                 .variant_name()
167                 .unwrap_or_default()
168                 .to_string(),
169         ));
170     }
171 
172     Ok(table.get::<Struct>(field.offset(), None))
173 }
174 
175 /// Gets a Vector table field given its exact type. Returns empty vector if the field is not set. Returns error if the type doesn't match.
176 ///
177 /// # Safety
178 ///
179 /// The value of the corresponding slot must have type Vector
get_field_vector<'a, T: Follow<'a, Inner = T>>( table: &Table<'a>, field: &Field, ) -> FlatbufferResult<Option<Vector<'a, T>>>180 pub unsafe fn get_field_vector<'a, T: Follow<'a, Inner = T>>(
181     table: &Table<'a>,
182     field: &Field,
183 ) -> FlatbufferResult<Option<Vector<'a, T>>> {
184     if field.type_().base_type() != BaseType::Vector
185         || core::mem::size_of::<T>() != get_type_size(field.type_().element())
186     {
187         return Err(FlatbufferError::FieldTypeMismatch(
188             std::any::type_name::<T>().to_string(),
189             field
190                 .type_()
191                 .base_type()
192                 .variant_name()
193                 .unwrap_or_default()
194                 .to_string(),
195         ));
196     }
197 
198     Ok(table.get::<ForwardsUOffset<Vector<'a, T>>>(field.offset(), Some(Vector::<T>::default())))
199 }
200 
201 /// Gets a Table table field given its exact type. Returns [None] if the field is not set. Returns error if the type doesn't match.
202 ///
203 /// # Safety
204 ///
205 /// The value of the corresponding slot must have type Table
get_field_table<'a>( table: &Table<'a>, field: &Field, ) -> FlatbufferResult<Option<Table<'a>>>206 pub unsafe fn get_field_table<'a>(
207     table: &Table<'a>,
208     field: &Field,
209 ) -> FlatbufferResult<Option<Table<'a>>> {
210     if field.type_().base_type() != BaseType::Obj {
211         return Err(FlatbufferError::FieldTypeMismatch(
212             String::from("Obj"),
213             field
214                 .type_()
215                 .base_type()
216                 .variant_name()
217                 .unwrap_or_default()
218                 .to_string(),
219         ));
220     }
221 
222     Ok(table.get::<ForwardsUOffset<Table<'a>>>(field.offset(), None))
223 }
224 
225 /// Returns the value of any table field as a 64-bit int, regardless of what type it is. Returns default integer if the field is not set or error if the value cannot be parsed as integer.
226 /// [num_traits](https://docs.rs/num-traits/latest/num_traits/cast/trait.NumCast.html) is used for number casting.
227 ///
228 /// # Safety
229 ///
230 /// [table] must contain recursively valid offsets that match the [field].
get_any_field_integer(table: &Table, field: &Field) -> FlatbufferResult<i64>231 pub unsafe fn get_any_field_integer(table: &Table, field: &Field) -> FlatbufferResult<i64> {
232     if let Some(field_loc) = get_field_loc(table, field) {
233         get_any_value_integer(field.type_().base_type(), table.buf(), field_loc)
234     } else {
235         Ok(field.default_integer())
236     }
237 }
238 
239 /// Returns the value of any table field as a 64-bit floating point, regardless of what type it is. Returns default float if the field is not set or error if the value cannot be parsed as float.
240 ///
241 /// # Safety
242 ///
243 /// [table] must contain recursively valid offsets that match the [field].
get_any_field_float(table: &Table, field: &Field) -> FlatbufferResult<f64>244 pub unsafe fn get_any_field_float(table: &Table, field: &Field) -> FlatbufferResult<f64> {
245     if let Some(field_loc) = get_field_loc(table, field) {
246         get_any_value_float(field.type_().base_type(), table.buf(), field_loc)
247     } else {
248         Ok(field.default_real())
249     }
250 }
251 
252 /// Returns the value of any table field as a string, regardless of what type it is. Returns empty string if the field is not set.
253 ///
254 /// # Safety
255 ///
256 /// [table] must contain recursively valid offsets that match the [field].
get_any_field_string(table: &Table, field: &Field, schema: &Schema) -> String257 pub unsafe fn get_any_field_string(table: &Table, field: &Field, schema: &Schema) -> String {
258     if let Some(field_loc) = get_field_loc(table, field) {
259         get_any_value_string(
260             field.type_().base_type(),
261             table.buf(),
262             field_loc,
263             schema,
264             field.type_().index() as usize,
265         )
266     } else {
267         String::from("")
268     }
269 }
270 
271 /// Gets a [Struct] struct field given its exact type. Returns error if the type doesn't match.
272 ///
273 /// # Safety
274 ///
275 /// The value of the corresponding slot must have type Struct.
get_field_struct_in_struct<'a>( st: &Struct<'a>, field: &Field, ) -> FlatbufferResult<Struct<'a>>276 pub unsafe fn get_field_struct_in_struct<'a>(
277     st: &Struct<'a>,
278     field: &Field,
279 ) -> FlatbufferResult<Struct<'a>> {
280     // TODO inherited from C++: This does NOT check if the field is a table or struct, but we'd need
281     // access to the schema to check the is_struct flag.
282     if field.type_().base_type() != BaseType::Obj {
283         return Err(FlatbufferError::FieldTypeMismatch(
284             String::from("Obj"),
285             field
286                 .type_()
287                 .base_type()
288                 .variant_name()
289                 .unwrap_or_default()
290                 .to_string(),
291         ));
292     }
293 
294     Ok(st.get::<Struct>(field.offset() as usize))
295 }
296 
297 /// Returns the value of any struct field as a 64-bit int, regardless of what type it is. Returns error if the value cannot be parsed as integer.
298 ///
299 /// # Safety
300 ///
301 /// [st] must contain valid offsets that match the [field].
get_any_field_integer_in_struct(st: &Struct, field: &Field) -> FlatbufferResult<i64>302 pub unsafe fn get_any_field_integer_in_struct(st: &Struct, field: &Field) -> FlatbufferResult<i64> {
303     let field_loc = st.loc() + field.offset() as usize;
304 
305     get_any_value_integer(field.type_().base_type(), st.buf(), field_loc)
306 }
307 
308 /// Returns the value of any struct field as a 64-bit floating point, regardless of what type it is. Returns error if the value cannot be parsed as float.
309 ///
310 /// # Safety
311 ///
312 /// [st] must contain valid offsets that match the [field].
get_any_field_float_in_struct(st: &Struct, field: &Field) -> FlatbufferResult<f64>313 pub unsafe fn get_any_field_float_in_struct(st: &Struct, field: &Field) -> FlatbufferResult<f64> {
314     let field_loc = st.loc() + field.offset() as usize;
315 
316     get_any_value_float(field.type_().base_type(), st.buf(), field_loc)
317 }
318 
319 /// Returns the value of any struct field as a string, regardless of what type it is.
320 ///
321 /// # Safety
322 ///
323 /// [st] must contain valid offsets that match the [field].
get_any_field_string_in_struct( st: &Struct, field: &Field, schema: &Schema, ) -> String324 pub unsafe fn get_any_field_string_in_struct(
325     st: &Struct,
326     field: &Field,
327     schema: &Schema,
328 ) -> String {
329     let field_loc = st.loc() + field.offset() as usize;
330 
331     get_any_value_string(
332         field.type_().base_type(),
333         st.buf(),
334         field_loc,
335         schema,
336         field.type_().index() as usize,
337     )
338 }
339 
340 /// Sets any table field with the value of a 64-bit integer. Returns error if the field is not originally set or is with non-scalar value or the provided value cannot be cast into the field type.
341 ///
342 /// # Safety
343 ///
344 /// [buf] must contain a valid root table and valid offset to it.
set_any_field_integer( buf: &mut [u8], table_loc: usize, field: &Field, v: i64, ) -> FlatbufferResult<()>345 pub unsafe fn set_any_field_integer(
346     buf: &mut [u8],
347     table_loc: usize,
348     field: &Field,
349     v: i64,
350 ) -> FlatbufferResult<()> {
351     let field_type = field.type_().base_type();
352     let table = Table::follow(buf, table_loc);
353 
354     let Some(field_loc) = get_field_loc(&table, field) else {
355         return Err(FlatbufferError::SetValueNotSupported);
356     };
357 
358     if !is_scalar(field_type) {
359         return Err(FlatbufferError::SetValueNotSupported);
360     }
361 
362     set_any_value_integer(field_type, buf, field_loc, v)
363 }
364 
365 /// Sets any table field with the value of a 64-bit floating point. Returns error if the field is not originally set or is with non-scalar value or the provided value cannot be cast into the field type.
366 ///
367 /// # Safety
368 ///
369 /// [buf] must contain a valid root table and valid offset to it.
set_any_field_float( buf: &mut [u8], table_loc: usize, field: &Field, v: f64, ) -> FlatbufferResult<()>370 pub unsafe fn set_any_field_float(
371     buf: &mut [u8],
372     table_loc: usize,
373     field: &Field,
374     v: f64,
375 ) -> FlatbufferResult<()> {
376     let field_type = field.type_().base_type();
377     let table = Table::follow(buf, table_loc);
378 
379     let Some(field_loc) = get_field_loc(&table, field) else {
380         return Err(FlatbufferError::SetValueNotSupported);
381     };
382 
383     if !is_scalar(field_type) {
384         return Err(FlatbufferError::SetValueNotSupported);
385     }
386 
387     set_any_value_float(field_type, buf, field_loc, v)
388 }
389 
390 /// Sets any table field with the value of a string. Returns error if the field is not originally set or is with non-scalar value or the provided value cannot be parsed as the field type.
391 ///
392 /// # Safety
393 ///
394 /// [buf] must contain a valid root table and valid offset to it.
set_any_field_string( buf: &mut [u8], table_loc: usize, field: &Field, v: &str, ) -> FlatbufferResult<()>395 pub unsafe fn set_any_field_string(
396     buf: &mut [u8],
397     table_loc: usize,
398     field: &Field,
399     v: &str,
400 ) -> FlatbufferResult<()> {
401     let field_type = field.type_().base_type();
402     let table = Table::follow(buf, table_loc);
403 
404     let Some(field_loc) = get_field_loc(&table, field) else {
405         return Err(FlatbufferError::SetValueNotSupported);
406     };
407 
408     if !is_scalar(field_type) {
409         return Err(FlatbufferError::SetValueNotSupported);
410     }
411 
412     set_any_value_float(field_type, buf, field_loc, v.parse::<f64>()?)
413 }
414 
415 /// Sets any scalar field given its exact type. Returns error if the field is not originally set or is with non-scalar value.
416 ///
417 /// # Safety
418 ///
419 /// [buf] must contain a valid root table and valid offset to it.
set_field<T: EndianScalar>( buf: &mut [u8], table_loc: usize, field: &Field, v: T, ) -> FlatbufferResult<()>420 pub unsafe fn set_field<T: EndianScalar>(
421     buf: &mut [u8],
422     table_loc: usize,
423     field: &Field,
424     v: T,
425 ) -> FlatbufferResult<()> {
426     let field_type = field.type_().base_type();
427     let table = Table::follow(buf, table_loc);
428 
429     if !is_scalar(field_type) {
430         return Err(FlatbufferError::SetValueNotSupported);
431     }
432 
433     if core::mem::size_of::<T>() != get_type_size(field_type) {
434         return Err(FlatbufferError::FieldTypeMismatch(
435             std::any::type_name::<T>().to_string(),
436             field_type.variant_name().unwrap_or_default().to_string(),
437         ));
438     }
439 
440     let Some(field_loc) = get_field_loc(&table, field) else {
441         return Err(FlatbufferError::SetValueNotSupported);
442     };
443 
444     if buf.len() < field_loc.saturating_add(get_type_size(field_type)) {
445         return Err(FlatbufferError::VerificationError(
446             InvalidFlatbuffer::RangeOutOfBounds {
447                 range: core::ops::Range {
448                     start: field_loc,
449                     end: field_loc.saturating_add(get_type_size(field_type)),
450                 },
451                 error_trace: Default::default(),
452             },
453         ));
454     }
455 
456     // SAFETY: the buffer range was verified above.
457     unsafe { Ok(emplace_scalar::<T>(&mut buf[field_loc..], v)) }
458 }
459 
460 /// Sets a string field to a new value. Returns error if the field is not originally set or is not of string type in which cases the [buf] stays intact. Returns error if the [buf] fails to be updated.
461 ///
462 /// # Safety
463 ///
464 /// [buf] must contain a valid root table and valid offset to it and conform to the [schema].
set_string( buf: &mut Vec<u8>, table_loc: usize, field: &Field, v: &str, schema: &Schema, ) -> FlatbufferResult<()>465 pub unsafe fn set_string(
466     buf: &mut Vec<u8>,
467     table_loc: usize,
468     field: &Field,
469     v: &str,
470     schema: &Schema,
471 ) -> FlatbufferResult<()> {
472     if v.is_empty() {
473         return Ok(());
474     }
475 
476     let field_type = field.type_().base_type();
477     if field_type != BaseType::String {
478         return Err(FlatbufferError::FieldTypeMismatch(
479             String::from("String"),
480             field_type.variant_name().unwrap_or_default().to_string(),
481         ));
482     }
483 
484     let table = Table::follow(buf, table_loc);
485 
486     let Some(field_loc) = get_field_loc(&table, field) else {
487         return Err(FlatbufferError::SetValueNotSupported);
488     };
489 
490     if buf.len() < field_loc + get_type_size(field_type) {
491         return Err(FlatbufferError::VerificationError(
492             InvalidFlatbuffer::RangeOutOfBounds {
493                 range: core::ops::Range {
494                     start: field_loc,
495                     end: field_loc.saturating_add(get_type_size(field_type)),
496                 },
497                 error_trace: Default::default(),
498             },
499         ));
500     }
501 
502     // SAFETY: the buffer range was verified above.
503     let string_loc = unsafe { deref_uoffset(buf, field_loc)? };
504     if buf.len() < string_loc.saturating_add(SIZE_UOFFSET) {
505         return Err(FlatbufferError::VerificationError(
506             InvalidFlatbuffer::RangeOutOfBounds {
507                 range: core::ops::Range {
508                     start: string_loc,
509                     end: string_loc.saturating_add(SIZE_UOFFSET),
510                 },
511                 error_trace: Default::default(),
512             },
513         ));
514     }
515 
516     // SAFETY: the buffer range was verified above.
517     let len_old = unsafe { read_uoffset(buf, string_loc) };
518     if buf.len()
519         < string_loc
520             .saturating_add(SIZE_UOFFSET)
521             .saturating_add(len_old.try_into()?)
522     {
523         return Err(FlatbufferError::VerificationError(
524             InvalidFlatbuffer::RangeOutOfBounds {
525                 range: core::ops::Range {
526                     start: string_loc,
527                     end: string_loc
528                         .saturating_add(SIZE_UOFFSET)
529                         .saturating_add(len_old.try_into()?),
530                 },
531                 error_trace: Default::default(),
532             },
533         ));
534     }
535 
536     let len_new = v.len();
537     let delta = len_new as isize - len_old as isize;
538     let mut bytes_to_insert = v.as_bytes().to_vec();
539 
540     if delta != 0 {
541         // Rounds the delta up to the nearest multiple of the maximum int size to keep the types after the insersion point aligned.
542         // stdint crate defines intmax_t as an alias for c_long; use it directly to avoid extra
543         // dependency.
544         let mask = (size_of::<core::ffi::c_long>() - 1) as isize;
545         let offset = (delta + mask) & !mask;
546         let mut visited_vec = vec![false; buf.len()];
547 
548         if offset != 0 {
549             update_offset(
550                 buf,
551                 table_loc,
552                 &mut visited_vec,
553                 &schema.root_table().unwrap(),
554                 schema,
555                 string_loc,
556                 offset,
557             )?;
558 
559             // Sets the new length.
560             emplace_scalar::<SOffsetT>(
561                 &mut buf[string_loc..string_loc + SIZE_UOFFSET],
562                 len_new.try_into()?,
563             );
564         }
565 
566         // Pads the bytes vector with 0 if `offset` doesn't equal `delta`.
567         bytes_to_insert.resize(bytes_to_insert.len() + (offset - delta) as usize, 0);
568     }
569 
570     // Replaces the data.
571     buf.splice(
572         string_loc + SIZE_SOFFSET..string_loc + SIZE_UOFFSET + usize::try_from(len_old)?,
573         bytes_to_insert,
574     );
575     Ok(())
576 }
577 
578 /// Returns the size of a scalar type in the `BaseType` enum. In the case of structs, returns the size of their offset (`UOffsetT`) in the buffer.
get_type_size(base_type: BaseType) -> usize579 fn get_type_size(base_type: BaseType) -> usize {
580     match base_type {
581         BaseType::UType | BaseType::Bool | BaseType::Byte | BaseType::UByte => 1,
582         BaseType::Short | BaseType::UShort => 2,
583         BaseType::Int
584         | BaseType::UInt
585         | BaseType::Float
586         | BaseType::String
587         | BaseType::Vector
588         | BaseType::Obj
589         | BaseType::Union => 4,
590         BaseType::Long | BaseType::ULong | BaseType::Double | BaseType::Vector64 => 8,
591         _ => 0,
592     }
593 }
594 
595 /// Returns the absolute field location in the buffer and [None] if the field is not populated.
596 ///
597 /// # Safety
598 ///
599 /// [table] must contain a valid vtable.
get_field_loc(table: &Table, field: &Field) -> Option<usize>600 unsafe fn get_field_loc(table: &Table, field: &Field) -> Option<usize> {
601     let field_offset = table.vtable().get(field.offset()) as usize;
602     if field_offset == 0 {
603         return None;
604     }
605 
606     Some(table.loc() + field_offset)
607 }
608 
609 /// Reads value as a 64-bit int from the provided byte slice at the specified location. Returns error if the value cannot be parsed as integer.
610 ///
611 /// # Safety
612 ///
613 /// Caller must ensure `buf.len() >= loc + size_of::<T>()` at all the access layers.
get_any_value_integer( base_type: BaseType, buf: &[u8], loc: usize, ) -> FlatbufferResult<i64>614 unsafe fn get_any_value_integer(
615     base_type: BaseType,
616     buf: &[u8],
617     loc: usize,
618 ) -> FlatbufferResult<i64> {
619     match base_type {
620         BaseType::UType | BaseType::UByte => i64::from_u8(u8::follow(buf, loc)),
621         BaseType::Bool => bool::follow(buf, loc).try_into().ok(),
622         BaseType::Byte => i64::from_i8(i8::follow(buf, loc)),
623         BaseType::Short => i64::from_i16(i16::follow(buf, loc)),
624         BaseType::UShort => i64::from_u16(u16::follow(buf, loc)),
625         BaseType::Int => i64::from_i32(i32::follow(buf, loc)),
626         BaseType::UInt => i64::from_u32(u32::follow(buf, loc)),
627         BaseType::Long => Some(i64::follow(buf, loc)),
628         BaseType::ULong => i64::from_u64(u64::follow(buf, loc)),
629         BaseType::Float => i64::from_f32(f32::follow(buf, loc)),
630         BaseType::Double => i64::from_f64(f64::follow(buf, loc)),
631         BaseType::String => ForwardsUOffset::<&str>::follow(buf, loc)
632             .parse::<i64>()
633             .ok(),
634         _ => None, // Tables & vectors do not make sense.
635     }
636     .ok_or(FlatbufferError::FieldTypeMismatch(
637         String::from("i64"),
638         base_type.variant_name().unwrap_or_default().to_string(),
639     ))
640 }
641 
642 /// Reads value as a 64-bit floating point from the provided byte slice at the specified location. Returns error if the value cannot be parsed as float.
643 ///
644 /// # Safety
645 ///
646 /// Caller must ensure `buf.len() >= loc + size_of::<T>()` at all the access layers.
get_any_value_float( base_type: BaseType, buf: &[u8], loc: usize, ) -> FlatbufferResult<f64>647 unsafe fn get_any_value_float(
648     base_type: BaseType,
649     buf: &[u8],
650     loc: usize,
651 ) -> FlatbufferResult<f64> {
652     match base_type {
653         BaseType::UType | BaseType::UByte => f64::from_u8(u8::follow(buf, loc)),
654         BaseType::Bool => bool::follow(buf, loc).try_into().ok(),
655         BaseType::Byte => f64::from_i8(i8::follow(buf, loc)),
656         BaseType::Short => f64::from_i16(i16::follow(buf, loc)),
657         BaseType::UShort => f64::from_u16(u16::follow(buf, loc)),
658         BaseType::Int => f64::from_i32(i32::follow(buf, loc)),
659         BaseType::UInt => f64::from_u32(u32::follow(buf, loc)),
660         BaseType::Long => f64::from_i64(i64::follow(buf, loc)),
661         BaseType::ULong => f64::from_u64(u64::follow(buf, loc)),
662         BaseType::Float => f64::from_f32(f32::follow(buf, loc)),
663         BaseType::Double => Some(f64::follow(buf, loc)),
664         BaseType::String => ForwardsUOffset::<&str>::follow(buf, loc)
665             .parse::<f64>()
666             .ok(),
667         _ => None,
668     }
669     .ok_or(FlatbufferError::FieldTypeMismatch(
670         String::from("f64"),
671         base_type.variant_name().unwrap_or_default().to_string(),
672     ))
673 }
674 
675 /// Reads value as a string from the provided byte slice at the specified location.
676 ///
677 /// # Safety
678 ///
679 /// Caller must ensure `buf.len() >= loc + size_of::<T>()` at all the access layers.
get_any_value_string( base_type: BaseType, buf: &[u8], loc: usize, schema: &Schema, type_index: usize, ) -> String680 unsafe fn get_any_value_string(
681     base_type: BaseType,
682     buf: &[u8],
683     loc: usize,
684     schema: &Schema,
685     type_index: usize,
686 ) -> String {
687     match base_type {
688         BaseType::Float | BaseType::Double => get_any_value_float(base_type, buf, loc)
689             .unwrap_or_default()
690             .to_string(),
691         BaseType::String => {
692             String::from_utf8_lossy(ForwardsUOffset::<&[u8]>::follow(buf, loc)).to_string()
693         }
694         BaseType::Obj => {
695             // Converts the table to a string. This is mostly for debugging purposes,
696             // and does NOT promise to be JSON compliant.
697             // Also prefixes the type.
698             let object: Object = schema.objects().get(type_index);
699             let mut s = object.name().to_string();
700             s += " { ";
701             if object.is_struct() {
702                 let st: Struct<'_> = Struct::follow(buf, loc);
703                 for field in object.fields() {
704                     let field_value = get_any_field_string_in_struct(&st, &field, schema);
705                     s += field.name();
706                     s += ": ";
707                     s += field_value.as_str();
708                     s += ", ";
709                 }
710             } else {
711                 let table = ForwardsUOffset::<Table>::follow(buf, loc);
712                 for field in object.fields() {
713                     if table.vtable().get(field.offset()) == 0 {
714                         continue;
715                     }
716                     let mut field_value = get_any_field_string(&table, &field, schema);
717                     if field.type_().base_type() == BaseType::String {
718                         // Escape the string
719                         field_value = format!("{:?}", field_value.as_str());
720                     }
721                     s += field.name();
722                     s += ": ";
723                     s += field_value.as_str();
724                     s += ", ";
725                 }
726             }
727             s + "}"
728         }
729         BaseType::Vector => String::from("[(elements)]"), // TODO inherited from C++: implement this as well.
730         BaseType::Union => String::from("(union)"), // TODO inherited from C++: implement this as well.
731         _ => get_any_value_integer(base_type, buf, loc)
732             .unwrap_or_default()
733             .to_string(),
734     }
735 }
736 
737 /// Sets any scalar value with a 64-bit integer. Returns error if the value is not successfully replaced.
set_any_value_integer( base_type: BaseType, buf: &mut [u8], field_loc: usize, v: i64, ) -> FlatbufferResult<()>738 fn set_any_value_integer(
739     base_type: BaseType,
740     buf: &mut [u8],
741     field_loc: usize,
742     v: i64,
743 ) -> FlatbufferResult<()> {
744     if buf.len() < get_type_size(base_type) {
745         return Err(FlatbufferError::VerificationError(
746             InvalidFlatbuffer::RangeOutOfBounds {
747                 range: core::ops::Range {
748                     start: field_loc,
749                     end: field_loc.saturating_add(get_type_size(base_type)),
750                 },
751                 error_trace: Default::default(),
752             },
753         ));
754     }
755     let buf = &mut buf[field_loc..];
756     let type_name = base_type.variant_name().unwrap_or_default().to_string();
757 
758     macro_rules! try_emplace {
759         ($ty:ty, $value:expr) => {
760             if let Ok(v) = TryInto::<$ty>::try_into($value) {
761                 // SAFETY: buffer size is verified at the beginning of this function.
762                 unsafe { Ok(emplace_scalar::<$ty>(buf, v)) }
763             } else {
764                 Err(FlatbufferError::FieldTypeMismatch(
765                     String::from("i64"),
766                     type_name,
767                 ))
768             }
769         };
770     }
771 
772     match base_type {
773         BaseType::UType | BaseType::UByte => {
774             try_emplace!(u8, v)
775         }
776         BaseType::Bool => {
777             // SAFETY: buffer size is verified at the beginning of this function.
778             unsafe { Ok(emplace_scalar::<bool>(buf, v != 0)) }
779         }
780         BaseType::Byte => {
781             try_emplace!(i8, v)
782         }
783         BaseType::Short => {
784             try_emplace!(i16, v)
785         }
786         BaseType::UShort => {
787             try_emplace!(u16, v)
788         }
789         BaseType::Int => {
790             try_emplace!(i32, v)
791         }
792         BaseType::UInt => {
793             try_emplace!(u32, v)
794         }
795         BaseType::Long => {
796             // SAFETY: buffer size is verified at the beginning of this function.
797             unsafe { Ok(emplace_scalar::<i64>(buf, v)) }
798         }
799         BaseType::ULong => {
800             try_emplace!(u64, v)
801         }
802         BaseType::Float => {
803             if let Some(value) = f32::from_i64(v) {
804                 // SAFETY: buffer size is verified at the beginning of this function.
805                 unsafe { Ok(emplace_scalar::<f32>(buf, value)) }
806             } else {
807                 Err(FlatbufferError::FieldTypeMismatch(
808                     String::from("i64"),
809                     type_name,
810                 ))
811             }
812         }
813         BaseType::Double => {
814             if let Some(value) = f64::from_i64(v) {
815                 // SAFETY: buffer size is verified at the beginning of this function.
816                 unsafe { Ok(emplace_scalar::<f64>(buf, value)) }
817             } else {
818                 Err(FlatbufferError::FieldTypeMismatch(
819                     String::from("i64"),
820                     type_name,
821                 ))
822             }
823         }
824         _ => Err(FlatbufferError::SetValueNotSupported),
825     }
826 }
827 
828 /// Sets any scalar value with a 64-bit floating point. Returns error if the value is not successfully replaced.
set_any_value_float( base_type: BaseType, buf: &mut [u8], field_loc: usize, v: f64, ) -> FlatbufferResult<()>829 fn set_any_value_float(
830     base_type: BaseType,
831     buf: &mut [u8],
832     field_loc: usize,
833     v: f64,
834 ) -> FlatbufferResult<()> {
835     if buf.len() < get_type_size(base_type) {
836         return Err(FlatbufferError::VerificationError(
837             InvalidFlatbuffer::RangeOutOfBounds {
838                 range: core::ops::Range {
839                     start: field_loc,
840                     end: field_loc.saturating_add(get_type_size(base_type)),
841                 },
842                 error_trace: Default::default(),
843             },
844         ));
845     }
846     let buf = &mut buf[field_loc..];
847     let type_name = base_type.variant_name().unwrap_or_default().to_string();
848 
849     match base_type {
850         BaseType::UType | BaseType::UByte => {
851             if let Some(value) = u8::from_f64(v) {
852                 // SAFETY: buffer size is verified at the beginning of this function.
853                 unsafe {
854                     return Ok(emplace_scalar::<u8>(buf, value));
855                 }
856             }
857         }
858         BaseType::Bool => {
859             // SAFETY: buffer size is verified at the beginning of this function.
860             unsafe {
861                 return Ok(emplace_scalar::<bool>(buf, v != 0f64));
862             }
863         }
864         BaseType::Byte => {
865             if let Some(value) = i8::from_f64(v) {
866                 // SAFETY: buffer size is verified at the beginning of this function.
867                 unsafe {
868                     return Ok(emplace_scalar::<i8>(buf, value));
869                 }
870             }
871         }
872         BaseType::Short => {
873             if let Some(value) = i16::from_f64(v) {
874                 // SAFETY: buffer size is verified at the beginning of this function.
875                 unsafe {
876                     return Ok(emplace_scalar::<i16>(buf, value));
877                 }
878             }
879         }
880         BaseType::UShort => {
881             if let Some(value) = u16::from_f64(v) {
882                 // SAFETY: buffer size is verified at the beginning of this function.
883                 unsafe {
884                     return Ok(emplace_scalar::<u16>(buf, value));
885                 }
886             }
887         }
888         BaseType::Int => {
889             if let Some(value) = i32::from_f64(v) {
890                 // SAFETY: buffer size is verified at the beginning of this function.
891                 unsafe {
892                     return Ok(emplace_scalar::<i32>(buf, value));
893                 }
894             }
895         }
896         BaseType::UInt => {
897             if let Some(value) = u32::from_f64(v) {
898                 // SAFETY: buffer size is verified at the beginning of this function.
899                 unsafe {
900                     return Ok(emplace_scalar::<u32>(buf, value));
901                 }
902             }
903         }
904         BaseType::Long => {
905             if let Some(value) = i64::from_f64(v) {
906                 // SAFETY: buffer size is verified at the beginning of this function.
907                 unsafe {
908                     return Ok(emplace_scalar::<i64>(buf, value));
909                 }
910             }
911         }
912         BaseType::ULong => {
913             if let Some(value) = u64::from_f64(v) {
914                 // SAFETY: buffer size is verified at the beginning of this function.
915                 unsafe {
916                     return Ok(emplace_scalar::<u64>(buf, value));
917                 }
918             }
919         }
920         BaseType::Float => {
921             if let Some(value) = f32::from_f64(v) {
922                 // Value converted to inf if overflow occurs
923                 if value != f32::INFINITY {
924                     // SAFETY: buffer size is verified at the beginning of this function.
925                     unsafe {
926                         return Ok(emplace_scalar::<f32>(buf, value));
927                     }
928                 }
929             }
930         }
931         BaseType::Double => {
932             // SAFETY: buffer size is verified at the beginning of this function.
933             unsafe {
934                 return Ok(emplace_scalar::<f64>(buf, v));
935             }
936         }
937         _ => return Err(FlatbufferError::SetValueNotSupported),
938     }
939     return Err(FlatbufferError::FieldTypeMismatch(
940         String::from("f64"),
941         type_name,
942     ));
943 }
944 
is_scalar(base_type: BaseType) -> bool945 fn is_scalar(base_type: BaseType) -> bool {
946     return base_type <= BaseType::Double;
947 }
948 
949 /// Iterates through the buffer and updates all the relative offsets affected by the insertion.
950 ///
951 /// # Safety
952 ///
953 /// Caller must ensure [buf] contains valid data that conforms to [schema].
update_offset( buf: &mut [u8], table_loc: usize, updated: &mut [bool], object: &Object, schema: &Schema, insertion_loc: usize, offset: isize, ) -> FlatbufferResult<()>954 unsafe fn update_offset(
955     buf: &mut [u8],
956     table_loc: usize,
957     updated: &mut [bool],
958     object: &Object,
959     schema: &Schema,
960     insertion_loc: usize,
961     offset: isize,
962 ) -> FlatbufferResult<()> {
963     if updated.len() != buf.len() {
964         return Err(FlatbufferError::SetStringPolluted);
965     }
966 
967     if updated[table_loc] {
968         return Ok(());
969     }
970 
971     let slice = &mut buf[table_loc..table_loc + SIZE_SOFFSET];
972     let vtable_offset = isize::try_from(read_scalar::<SOffsetT>(slice))?;
973     let vtable_loc = (isize::try_from(table_loc)? - vtable_offset).try_into()?;
974 
975     if insertion_loc <= table_loc {
976         // Checks if insertion point is between the table and a vtable that
977         // precedes it.
978         if (vtable_loc..table_loc).contains(&insertion_loc) {
979             emplace_scalar::<SOffsetT>(slice, (vtable_offset + offset).try_into()?);
980             updated[table_loc] = true;
981         }
982 
983         // Early out: since all fields inside the table must point forwards in
984         // memory, if the insertion point is before the table we can stop here.
985         return Ok(());
986     }
987 
988     for field in object.fields() {
989         let field_type = field.type_().base_type();
990         if is_scalar(field_type) {
991             continue;
992         }
993 
994         let field_offset = VOffsetT::follow(buf, vtable_loc.saturating_add(field.offset().into()));
995         if field_offset == 0 {
996             continue;
997         }
998 
999         let field_loc = table_loc + usize::from(field_offset);
1000         if updated[field_loc] {
1001             continue;
1002         }
1003 
1004         if field_type == BaseType::Obj
1005             && schema
1006                 .objects()
1007                 .get(field.type_().index().try_into()?)
1008                 .is_struct()
1009         {
1010             continue;
1011         }
1012 
1013         // Updates the relative offset from table to actual data if needed
1014         let slice = &mut buf[field_loc..field_loc + SIZE_UOFFSET];
1015         let field_value_offset = read_scalar::<UOffsetT>(slice);
1016         let field_value_loc = field_loc.saturating_add(field_value_offset.try_into()?);
1017         if (field_loc..field_value_loc).contains(&insertion_loc) {
1018             emplace_scalar::<UOffsetT>(
1019                 slice,
1020                 (isize::try_from(field_value_offset)? + offset).try_into()?,
1021             );
1022             updated[field_loc] = true;
1023         }
1024 
1025         match field_type {
1026             BaseType::Obj => {
1027                 let field_obj = schema.objects().get(field.type_().index().try_into()?);
1028                 update_offset(
1029                     buf,
1030                     field_value_loc,
1031                     updated,
1032                     &field_obj,
1033                     schema,
1034                     insertion_loc,
1035                     offset,
1036                 )?;
1037             }
1038             BaseType::Vector => {
1039                 let elem_type = field.type_().element();
1040                 if elem_type != BaseType::Obj || elem_type != BaseType::String {
1041                     continue;
1042                 }
1043                 if elem_type == BaseType::Obj
1044                     && schema
1045                         .objects()
1046                         .get(field.type_().index().try_into()?)
1047                         .is_struct()
1048                 {
1049                     continue;
1050                 }
1051                 let vec_size = usize::try_from(read_uoffset(buf, field_value_loc))?;
1052                 for index in 0..vec_size {
1053                     let elem_loc = field_value_loc + SIZE_UOFFSET + index * SIZE_UOFFSET;
1054                     if updated[elem_loc] {
1055                         continue;
1056                     }
1057                     let slice = &mut buf[elem_loc..elem_loc + SIZE_UOFFSET];
1058                     let elem_value_offset = read_scalar::<UOffsetT>(slice);
1059                     let elem_value_loc = elem_loc.saturating_add(elem_value_offset.try_into()?);
1060                     if (elem_loc..elem_value_loc).contains(&insertion_loc) {
1061                         emplace_scalar::<UOffsetT>(
1062                             slice,
1063                             (isize::try_from(elem_value_offset)? + offset).try_into()?,
1064                         );
1065                         updated[elem_loc] = true;
1066                     }
1067 
1068                     if elem_type == BaseType::Obj {
1069                         let elem_obj = schema.objects().get(field.type_().index().try_into()?);
1070                         update_offset(
1071                             buf,
1072                             elem_value_loc,
1073                             updated,
1074                             &elem_obj,
1075                             schema,
1076                             insertion_loc,
1077                             offset,
1078                         )?;
1079                     }
1080                 }
1081             }
1082             BaseType::Union => {
1083                 let union_enum = schema.enums().get(field.type_().index().try_into()?);
1084                 let union_type = object
1085                     .fields()
1086                     .lookup_by_key(field.name().to_string() + "_type", |field, key| {
1087                         field.key_compare_with_value(key)
1088                     })
1089                     .unwrap();
1090                 let union_type_loc = vtable_loc.saturating_add(union_type.offset().into());
1091                 let union_type_offset = VOffsetT::follow(buf, union_type_loc);
1092                 let union_type_value =
1093                     u8::follow(buf, table_loc.saturating_add(union_type_offset.into()));
1094                 let union_enum_value = union_enum
1095                     .values()
1096                     .lookup_by_key(union_type_value.into(), |value, key| {
1097                         value.key_compare_with_value(*key)
1098                     })
1099                     .unwrap();
1100                 let union_object = schema
1101                     .objects()
1102                     .get(union_enum_value.union_type().unwrap().index().try_into()?);
1103                 update_offset(
1104                     buf,
1105                     field_value_loc,
1106                     updated,
1107                     &union_object,
1108                     schema,
1109                     insertion_loc,
1110                     offset,
1111                 )?;
1112             }
1113             _ => (),
1114         }
1115     }
1116 
1117     // Checks if the vtable offset points beyond the insertion point.
1118     if (table_loc..vtable_loc).contains(&insertion_loc) {
1119         let slice = &mut buf[table_loc..table_loc + SIZE_SOFFSET];
1120         emplace_scalar::<SOffsetT>(slice, (vtable_offset - offset).try_into()?);
1121         updated[table_loc] = true;
1122     }
1123     Ok(())
1124 }
1125 
1126 /// Returns the absolute location of the data (e.g. string) in the buffer when the field contains relative offset (`UOffsetT`) to the data.
1127 ///
1128 /// # Safety
1129 ///
1130 /// The value of the corresponding slot must have type `UOffsetT`.
deref_uoffset(buf: &[u8], field_loc: usize) -> FlatbufferResult<usize>1131 unsafe fn deref_uoffset(buf: &[u8], field_loc: usize) -> FlatbufferResult<usize> {
1132     Ok(field_loc.saturating_add(read_uoffset(buf, field_loc).try_into()?))
1133 }
1134 
1135 /// Reads the value of `UOffsetT` at the give location.
1136 ///
1137 /// # Safety
1138 ///
1139 /// The value of the corresponding slot must have type `UOffsetT`.
read_uoffset(buf: &[u8], loc: usize) -> UOffsetT1140 unsafe fn read_uoffset(buf: &[u8], loc: usize) -> UOffsetT {
1141     let slice = &buf[loc..loc + SIZE_UOFFSET];
1142     read_scalar::<UOffsetT>(slice)
1143 }
1144