1 // Copyright 2019 Google LLC 2 // 3 // Licensed under the Apache License, Version 2.0 (the "License"); 4 // you may not use this file except in compliance with the License. 5 // You may obtain a copy of the License at 6 // 7 // https://www.apache.org/licenses/LICENSE-2.0 8 // 9 // Unless required by applicable law or agreed to in writing, software 10 // distributed under the License is distributed on an "AS IS" BASIS, 11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 // See the License for the specific language governing permissions and 13 // limitations under the License. 14 15 use super::Error; 16 use crate::{FlexBufferType, Reader, ReaderIterator}; 17 use serde::de::{ 18 DeserializeSeed, Deserializer, EnumAccess, IntoDeserializer, MapAccess, SeqAccess, 19 VariantAccess, Visitor, 20 }; 21 22 /// Errors that may happen when deserializing a flexbuffer with serde. 23 #[derive(Debug, Clone, PartialEq, Eq)] 24 pub enum DeserializationError { 25 Reader(Error), 26 Serde(String), 27 } 28 29 impl std::error::Error for DeserializationError {} 30 impl std::fmt::Display for DeserializationError { fmt(&self, f: &mut std::fmt::Formatter) -> Result<(), std::fmt::Error>31 fn fmt(&self, f: &mut std::fmt::Formatter) -> Result<(), std::fmt::Error> { 32 match self { 33 Self::Reader(r) => write!(f, "Flexbuffer Read Error: {:?}", r), 34 Self::Serde(s) => write!(f, "Serde Error: {}", s), 35 } 36 } 37 } 38 39 impl serde::de::Error for DeserializationError { custom<T>(msg: T) -> Self where T: std::fmt::Display,40 fn custom<T>(msg: T) -> Self 41 where 42 T: std::fmt::Display, 43 { 44 Self::Serde(format!("{}", msg)) 45 } 46 } 47 48 impl std::convert::From<super::Error> for DeserializationError { from(e: super::Error) -> Self49 fn from(e: super::Error) -> Self { 50 Self::Reader(e) 51 } 52 } 53 54 impl<'de> SeqAccess<'de> for ReaderIterator<&'de [u8]> { 55 type Error = DeserializationError; 56 next_element_seed<T>( &mut self, seed: T, ) -> Result<Option<<T as DeserializeSeed<'de>>::Value>, Self::Error> where T: DeserializeSeed<'de>,57 fn next_element_seed<T>( 58 &mut self, 59 seed: T, 60 ) -> Result<Option<<T as DeserializeSeed<'de>>::Value>, Self::Error> 61 where 62 T: DeserializeSeed<'de>, 63 { 64 if let Some(elem) = self.next() { 65 seed.deserialize(elem).map(Some) 66 } else { 67 Ok(None) 68 } 69 } 70 size_hint(&self) -> Option<usize>71 fn size_hint(&self) -> Option<usize> { 72 Some(self.len()) 73 } 74 } 75 76 struct EnumReader<'de> { 77 variant: &'de str, 78 value: Option<Reader<&'de [u8]>>, 79 } 80 81 impl<'de> EnumAccess<'de> for EnumReader<'de> { 82 type Error = DeserializationError; 83 type Variant = Reader<&'de [u8]>; 84 variant_seed<V>(self, seed: V) -> Result<(V::Value, Self::Variant), Self::Error> where V: DeserializeSeed<'de>,85 fn variant_seed<V>(self, seed: V) -> Result<(V::Value, Self::Variant), Self::Error> 86 where 87 V: DeserializeSeed<'de>, 88 { 89 seed.deserialize(self.variant.into_deserializer()) 90 .map(|v| (v, self.value.unwrap_or_default())) 91 } 92 } 93 94 struct MapAccessor<'de> { 95 keys: ReaderIterator<&'de [u8]>, 96 vals: ReaderIterator<&'de [u8]>, 97 } 98 99 impl<'de> MapAccess<'de> for MapAccessor<'de> { 100 type Error = DeserializationError; 101 next_key_seed<K>(&mut self, seed: K) -> Result<Option<K::Value>, Self::Error> where K: DeserializeSeed<'de>,102 fn next_key_seed<K>(&mut self, seed: K) -> Result<Option<K::Value>, Self::Error> 103 where 104 K: DeserializeSeed<'de>, 105 { 106 if let Some(k) = self.keys.next() { 107 seed.deserialize(k).map(Some) 108 } else { 109 Ok(None) 110 } 111 } 112 next_value_seed<V>(&mut self, seed: V) -> Result<V::Value, Self::Error> where V: DeserializeSeed<'de>,113 fn next_value_seed<V>(&mut self, seed: V) -> Result<V::Value, Self::Error> 114 where 115 V: DeserializeSeed<'de>, 116 { 117 let val = self.vals.next().ok_or(Error::IndexOutOfBounds)?; 118 seed.deserialize(val) 119 } 120 } 121 122 impl<'de> VariantAccess<'de> for Reader<&'de [u8]> { 123 type Error = DeserializationError; 124 unit_variant(self) -> Result<(), Self::Error>125 fn unit_variant(self) -> Result<(), Self::Error> { 126 Ok(()) 127 } 128 newtype_variant_seed<T>(self, seed: T) -> Result<T::Value, Self::Error> where T: DeserializeSeed<'de>,129 fn newtype_variant_seed<T>(self, seed: T) -> Result<T::Value, Self::Error> 130 where 131 T: DeserializeSeed<'de>, 132 { 133 seed.deserialize(self) 134 } 135 136 // Tuple variants have an internally tagged representation. They are vectors where Index 0 is 137 // the discriminant and index N is field N-1. tuple_variant<V>(self, _len: usize, visitor: V) -> Result<V::Value, Self::Error> where V: Visitor<'de>,138 fn tuple_variant<V>(self, _len: usize, visitor: V) -> Result<V::Value, Self::Error> 139 where 140 V: Visitor<'de>, 141 { 142 visitor.visit_seq(self.as_vector().iter()) 143 } 144 145 // Struct variants have an internally tagged representation. They are vectors where Index 0 is 146 // the discriminant and index N is field N-1. struct_variant<V>( self, _fields: &'static [&'static str], visitor: V, ) -> Result<V::Value, Self::Error> where V: Visitor<'de>,147 fn struct_variant<V>( 148 self, 149 _fields: &'static [&'static str], 150 visitor: V, 151 ) -> Result<V::Value, Self::Error> 152 where 153 V: Visitor<'de>, 154 { 155 let m = self.get_map()?; 156 visitor.visit_map(MapAccessor { 157 keys: m.keys_vector().iter(), 158 vals: m.iter_values(), 159 }) 160 } 161 } 162 163 impl<'de> Deserializer<'de> for Reader<&'de [u8]> { 164 type Error = DeserializationError; is_human_readable(&self) -> bool165 fn is_human_readable(&self) -> bool { 166 cfg!(deserialize_human_readable) 167 } 168 deserialize_any<V>(self, visitor: V) -> Result<V::Value, Self::Error> where V: Visitor<'de>,169 fn deserialize_any<V>(self, visitor: V) -> Result<V::Value, Self::Error> 170 where 171 V: Visitor<'de>, 172 { 173 use crate::BitWidth::*; 174 use crate::FlexBufferType::*; 175 match (self.flexbuffer_type(), self.bitwidth()) { 176 (Bool, _) => visitor.visit_bool(self.as_bool()), 177 (UInt, W8) => visitor.visit_u8(self.as_u8()), 178 (UInt, W16) => visitor.visit_u16(self.as_u16()), 179 (UInt, W32) => visitor.visit_u32(self.as_u32()), 180 (UInt, W64) => visitor.visit_u64(self.as_u64()), 181 (Int, W8) => visitor.visit_i8(self.as_i8()), 182 (Int, W16) => visitor.visit_i16(self.as_i16()), 183 (Int, W32) => visitor.visit_i32(self.as_i32()), 184 (Int, W64) => visitor.visit_i64(self.as_i64()), 185 (Float, W32) => visitor.visit_f32(self.as_f32()), 186 (Float, W64) => visitor.visit_f64(self.as_f64()), 187 (Float, _) => Err(Error::InvalidPackedType.into()), // f8 and f16 are not supported. 188 (Null, _) => visitor.visit_unit(), 189 (String, _) | (Key, _) => visitor.visit_borrowed_str(self.as_str()), 190 (Blob, _) => visitor.visit_borrowed_bytes(self.get_blob()?.0), 191 (Map, _) => { 192 let m = self.get_map()?; 193 visitor.visit_map(MapAccessor { 194 keys: m.keys_vector().iter(), 195 vals: m.iter_values(), 196 }) 197 } 198 (ty, _) if ty.is_vector() => visitor.visit_seq(self.as_vector().iter()), 199 (ty, bw) => unreachable!("TODO deserialize_any {:?} {:?}.", ty, bw), 200 } 201 } 202 203 serde::forward_to_deserialize_any! { 204 bool i8 i16 i32 i64 u8 u16 u32 u64 f32 f64 str unit unit_struct bytes 205 ignored_any map identifier struct tuple tuple_struct seq string 206 } 207 deserialize_char<V>(self, visitor: V) -> Result<V::Value, Self::Error> where V: Visitor<'de>,208 fn deserialize_char<V>(self, visitor: V) -> Result<V::Value, Self::Error> 209 where 210 V: Visitor<'de>, 211 { 212 visitor.visit_char(self.as_u8() as char) 213 } 214 deserialize_byte_buf<V>(self, visitor: V) -> Result<V::Value, Self::Error> where V: Visitor<'de>,215 fn deserialize_byte_buf<V>(self, visitor: V) -> Result<V::Value, Self::Error> 216 where 217 V: Visitor<'de>, 218 { 219 visitor.visit_byte_buf(self.get_blob()?.0.to_vec()) 220 } 221 deserialize_option<V>(self, visitor: V) -> Result<V::Value, Self::Error> where V: Visitor<'de>,222 fn deserialize_option<V>(self, visitor: V) -> Result<V::Value, Self::Error> 223 where 224 V: Visitor<'de>, 225 { 226 if self.flexbuffer_type() == FlexBufferType::Null { 227 visitor.visit_none() 228 } else { 229 visitor.visit_some(self) 230 } 231 } 232 deserialize_newtype_struct<V>( self, _name: &'static str, visitor: V, ) -> Result<V::Value, Self::Error> where V: Visitor<'de>,233 fn deserialize_newtype_struct<V>( 234 self, 235 _name: &'static str, 236 visitor: V, 237 ) -> Result<V::Value, Self::Error> 238 where 239 V: Visitor<'de>, 240 { 241 visitor.visit_newtype_struct(self) 242 } 243 deserialize_enum<V>( self, _name: &'static str, _variants: &'static [&'static str], visitor: V, ) -> Result<V::Value, Self::Error> where V: Visitor<'de>,244 fn deserialize_enum<V>( 245 self, 246 _name: &'static str, 247 _variants: &'static [&'static str], 248 visitor: V, 249 ) -> Result<V::Value, Self::Error> 250 where 251 V: Visitor<'de>, 252 { 253 let (variant, value) = match self.fxb_type { 254 FlexBufferType::String => (self.as_str(), None), 255 FlexBufferType::Map => { 256 let m = self.get_map()?; 257 let variant = m.keys_vector().idx(0).get_key()?; 258 let value = Some(m.idx(0)); 259 (variant, value) 260 } 261 _ => { 262 return Err(Error::UnexpectedFlexbufferType { 263 expected: FlexBufferType::Map, 264 actual: self.fxb_type, 265 } 266 .into()); 267 } 268 }; 269 visitor.visit_enum(EnumReader { variant, value }) 270 } 271 } 272