• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 // Copyright 2019 Google LLC
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 //     https://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14 
15 use super::Error;
16 use crate::{FlexBufferType, Reader, ReaderIterator};
17 use serde::de::{
18     DeserializeSeed, Deserializer, EnumAccess, IntoDeserializer, MapAccess, SeqAccess,
19     VariantAccess, Visitor,
20 };
21 
22 /// Errors that may happen when deserializing a flexbuffer with serde.
23 #[derive(Debug, Clone, PartialEq, Eq)]
24 pub enum DeserializationError {
25     Reader(Error),
26     Serde(String),
27 }
28 
29 impl std::error::Error for DeserializationError {}
30 impl std::fmt::Display for DeserializationError {
fmt(&self, f: &mut std::fmt::Formatter) -> Result<(), std::fmt::Error>31     fn fmt(&self, f: &mut std::fmt::Formatter) -> Result<(), std::fmt::Error> {
32         match self {
33             Self::Reader(r) => write!(f, "Flexbuffer Read Error: {:?}", r),
34             Self::Serde(s) => write!(f, "Serde Error: {}", s),
35         }
36     }
37 }
38 
39 impl serde::de::Error for DeserializationError {
custom<T>(msg: T) -> Self where T: std::fmt::Display,40     fn custom<T>(msg: T) -> Self
41     where
42         T: std::fmt::Display,
43     {
44         Self::Serde(format!("{}", msg))
45     }
46 }
47 
48 impl std::convert::From<super::Error> for DeserializationError {
from(e: super::Error) -> Self49     fn from(e: super::Error) -> Self {
50         Self::Reader(e)
51     }
52 }
53 
54 impl<'de> SeqAccess<'de> for ReaderIterator<&'de [u8]> {
55     type Error = DeserializationError;
56 
next_element_seed<T>( &mut self, seed: T, ) -> Result<Option<<T as DeserializeSeed<'de>>::Value>, Self::Error> where T: DeserializeSeed<'de>,57     fn next_element_seed<T>(
58         &mut self,
59         seed: T,
60     ) -> Result<Option<<T as DeserializeSeed<'de>>::Value>, Self::Error>
61     where
62         T: DeserializeSeed<'de>,
63     {
64         if let Some(elem) = self.next() {
65             seed.deserialize(elem).map(Some)
66         } else {
67             Ok(None)
68         }
69     }
70 
size_hint(&self) -> Option<usize>71     fn size_hint(&self) -> Option<usize> {
72         Some(self.len())
73     }
74 }
75 
76 struct EnumReader<'de> {
77     variant: &'de str,
78     value: Option<Reader<&'de [u8]>>,
79 }
80 
81 impl<'de> EnumAccess<'de> for EnumReader<'de> {
82     type Error = DeserializationError;
83     type Variant = Reader<&'de [u8]>;
84 
variant_seed<V>(self, seed: V) -> Result<(V::Value, Self::Variant), Self::Error> where V: DeserializeSeed<'de>,85     fn variant_seed<V>(self, seed: V) -> Result<(V::Value, Self::Variant), Self::Error>
86     where
87         V: DeserializeSeed<'de>,
88     {
89         seed.deserialize(self.variant.into_deserializer())
90             .map(|v| (v, self.value.unwrap_or_default()))
91     }
92 }
93 
94 struct MapAccessor<'de> {
95     keys: ReaderIterator<&'de [u8]>,
96     vals: ReaderIterator<&'de [u8]>,
97 }
98 
99 impl<'de> MapAccess<'de> for MapAccessor<'de> {
100     type Error = DeserializationError;
101 
next_key_seed<K>(&mut self, seed: K) -> Result<Option<K::Value>, Self::Error> where K: DeserializeSeed<'de>,102     fn next_key_seed<K>(&mut self, seed: K) -> Result<Option<K::Value>, Self::Error>
103     where
104         K: DeserializeSeed<'de>,
105     {
106         if let Some(k) = self.keys.next() {
107             seed.deserialize(k).map(Some)
108         } else {
109             Ok(None)
110         }
111     }
112 
next_value_seed<V>(&mut self, seed: V) -> Result<V::Value, Self::Error> where V: DeserializeSeed<'de>,113     fn next_value_seed<V>(&mut self, seed: V) -> Result<V::Value, Self::Error>
114     where
115         V: DeserializeSeed<'de>,
116     {
117         let val = self.vals.next().ok_or(Error::IndexOutOfBounds)?;
118         seed.deserialize(val)
119     }
120 }
121 
122 impl<'de> VariantAccess<'de> for Reader<&'de [u8]> {
123     type Error = DeserializationError;
124 
unit_variant(self) -> Result<(), Self::Error>125     fn unit_variant(self) -> Result<(), Self::Error> {
126         Ok(())
127     }
128 
newtype_variant_seed<T>(self, seed: T) -> Result<T::Value, Self::Error> where T: DeserializeSeed<'de>,129     fn newtype_variant_seed<T>(self, seed: T) -> Result<T::Value, Self::Error>
130     where
131         T: DeserializeSeed<'de>,
132     {
133         seed.deserialize(self)
134     }
135 
136     // Tuple variants have an internally tagged representation. They are vectors where Index 0 is
137     // the discriminant and index N is field N-1.
tuple_variant<V>(self, _len: usize, visitor: V) -> Result<V::Value, Self::Error> where V: Visitor<'de>,138     fn tuple_variant<V>(self, _len: usize, visitor: V) -> Result<V::Value, Self::Error>
139     where
140         V: Visitor<'de>,
141     {
142         visitor.visit_seq(self.as_vector().iter())
143     }
144 
145     // Struct variants have an internally tagged representation. They are vectors where Index 0 is
146     // the discriminant and index N is field N-1.
struct_variant<V>( self, _fields: &'static [&'static str], visitor: V, ) -> Result<V::Value, Self::Error> where V: Visitor<'de>,147     fn struct_variant<V>(
148         self,
149         _fields: &'static [&'static str],
150         visitor: V,
151     ) -> Result<V::Value, Self::Error>
152     where
153         V: Visitor<'de>,
154     {
155         let m = self.get_map()?;
156         visitor.visit_map(MapAccessor {
157             keys: m.keys_vector().iter(),
158             vals: m.iter_values(),
159         })
160     }
161 }
162 
163 impl<'de> Deserializer<'de> for Reader<&'de [u8]> {
164     type Error = DeserializationError;
is_human_readable(&self) -> bool165     fn is_human_readable(&self) -> bool {
166         cfg!(deserialize_human_readable)
167     }
168 
deserialize_any<V>(self, visitor: V) -> Result<V::Value, Self::Error> where V: Visitor<'de>,169     fn deserialize_any<V>(self, visitor: V) -> Result<V::Value, Self::Error>
170     where
171         V: Visitor<'de>,
172     {
173         use crate::BitWidth::*;
174         use crate::FlexBufferType::*;
175         match (self.flexbuffer_type(), self.bitwidth()) {
176             (Bool, _) => visitor.visit_bool(self.as_bool()),
177             (UInt, W8) => visitor.visit_u8(self.as_u8()),
178             (UInt, W16) => visitor.visit_u16(self.as_u16()),
179             (UInt, W32) => visitor.visit_u32(self.as_u32()),
180             (UInt, W64) => visitor.visit_u64(self.as_u64()),
181             (Int, W8) => visitor.visit_i8(self.as_i8()),
182             (Int, W16) => visitor.visit_i16(self.as_i16()),
183             (Int, W32) => visitor.visit_i32(self.as_i32()),
184             (Int, W64) => visitor.visit_i64(self.as_i64()),
185             (Float, W32) => visitor.visit_f32(self.as_f32()),
186             (Float, W64) => visitor.visit_f64(self.as_f64()),
187             (Float, _) => Err(Error::InvalidPackedType.into()), // f8 and f16 are not supported.
188             (Null, _) => visitor.visit_unit(),
189             (String, _) | (Key, _) => visitor.visit_borrowed_str(self.as_str()),
190             (Blob, _) => visitor.visit_borrowed_bytes(self.get_blob()?.0),
191             (Map, _) => {
192                 let m = self.get_map()?;
193                 visitor.visit_map(MapAccessor {
194                     keys: m.keys_vector().iter(),
195                     vals: m.iter_values(),
196                 })
197             }
198             (ty, _) if ty.is_vector() => visitor.visit_seq(self.as_vector().iter()),
199             (ty, bw) => unreachable!("TODO deserialize_any {:?} {:?}.", ty, bw),
200         }
201     }
202 
203     serde::forward_to_deserialize_any! {
204         bool i8 i16 i32 i64 u8 u16 u32 u64 f32 f64 str unit unit_struct bytes
205         ignored_any map identifier struct tuple tuple_struct seq string
206     }
207 
deserialize_char<V>(self, visitor: V) -> Result<V::Value, Self::Error> where V: Visitor<'de>,208     fn deserialize_char<V>(self, visitor: V) -> Result<V::Value, Self::Error>
209     where
210         V: Visitor<'de>,
211     {
212         visitor.visit_char(self.as_u8() as char)
213     }
214 
deserialize_byte_buf<V>(self, visitor: V) -> Result<V::Value, Self::Error> where V: Visitor<'de>,215     fn deserialize_byte_buf<V>(self, visitor: V) -> Result<V::Value, Self::Error>
216     where
217         V: Visitor<'de>,
218     {
219         visitor.visit_byte_buf(self.get_blob()?.0.to_vec())
220     }
221 
deserialize_option<V>(self, visitor: V) -> Result<V::Value, Self::Error> where V: Visitor<'de>,222     fn deserialize_option<V>(self, visitor: V) -> Result<V::Value, Self::Error>
223     where
224         V: Visitor<'de>,
225     {
226         if self.flexbuffer_type() == FlexBufferType::Null {
227             visitor.visit_none()
228         } else {
229             visitor.visit_some(self)
230         }
231     }
232 
deserialize_newtype_struct<V>( self, _name: &'static str, visitor: V, ) -> Result<V::Value, Self::Error> where V: Visitor<'de>,233     fn deserialize_newtype_struct<V>(
234         self,
235         _name: &'static str,
236         visitor: V,
237     ) -> Result<V::Value, Self::Error>
238     where
239         V: Visitor<'de>,
240     {
241         visitor.visit_newtype_struct(self)
242     }
243 
deserialize_enum<V>( self, _name: &'static str, _variants: &'static [&'static str], visitor: V, ) -> Result<V::Value, Self::Error> where V: Visitor<'de>,244     fn deserialize_enum<V>(
245         self,
246         _name: &'static str,
247         _variants: &'static [&'static str],
248         visitor: V,
249     ) -> Result<V::Value, Self::Error>
250     where
251         V: Visitor<'de>,
252     {
253         let (variant, value) = match self.fxb_type {
254             FlexBufferType::String => (self.as_str(), None),
255             FlexBufferType::Map => {
256                 let m = self.get_map()?;
257                 let variant = m.keys_vector().idx(0).get_key()?;
258                 let value = Some(m.idx(0));
259                 (variant, value)
260             }
261             _ => {
262                 return Err(Error::UnexpectedFlexbufferType {
263                     expected: FlexBufferType::Map,
264                     actual: self.fxb_type,
265                 }
266                 .into());
267             }
268         };
269         visitor.visit_enum(EnumReader { variant, value })
270     }
271 }
272