• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 // Copyright 2020 The ChromiumOS Authors
2 // Use of this source code is governed by a BSD-style license that can be
3 // found in the LICENSE file.
4 
5 //! Provides infrastructure for de/serializing descriptors embedded in Rust data structures.
6 //!
7 //! # Example
8 //!
9 //! ```
10 //! use serde_json::to_string;
11 //! use base::{
12 //!     FileSerdeWrapper, FromRawDescriptor, SafeDescriptor, SerializeDescriptors,
13 //!     deserialize_with_descriptors,
14 //! };
15 //! use tempfile::tempfile;
16 //!
17 //! let tmp_f = tempfile().unwrap();
18 //!
19 //! // Uses a simple wrapper to serialize a File because we can't implement Serialize for File.
20 //! let data = FileSerdeWrapper(tmp_f);
21 //!
22 //! // Wraps Serialize types to collect side channel descriptors as Serialize is called.
23 //! let data_wrapper = SerializeDescriptors::new(&data);
24 //!
25 //! // Use the wrapper with any serializer to serialize data as normal, grabbing descriptors
26 //! // as the data structures are serialized by the serializer.
27 //! let out_json = serde_json::to_string(&data_wrapper).expect("failed to serialize");
28 //!
29 //! // If data_wrapper contains any side channel descriptor refs
30 //! // (it contains tmp_f in this case), we can retrieve the actual descriptors
31 //! // from the side channel using into_descriptors().
32 //! let out_descriptors = data_wrapper.into_descriptors();
33 //!
34 //! // When sending out_json over some transport, also send out_descriptors.
35 //!
36 //! // For this example, we aren't really transporting data across the process, but we do need to
37 //! // convert the descriptor type.
38 //! let mut safe_descriptors = out_descriptors
39 //!     .iter()
40 //!     .map(|&v| Some(unsafe { SafeDescriptor::from_raw_descriptor(v) }))
41 //!     .collect();
42 //! std::mem::forget(data); // Prevent double drop of tmp_f.
43 //!
44 //! // The deserialize_with_descriptors function is used to give the descriptor deserializers access
45 //! // to side channel descriptors.
46 //! let res: FileSerdeWrapper =
47 //!     deserialize_with_descriptors(|| serde_json::from_str(&out_json), &mut safe_descriptors)
48 //!        .expect("failed to deserialize");
49 //! ```
50 
51 use std::cell::Cell;
52 use std::cell::RefCell;
53 use std::convert::TryInto;
54 use std::fmt;
55 use std::fs::File;
56 use std::ops::Deref;
57 use std::ops::DerefMut;
58 use std::panic::catch_unwind;
59 use std::panic::resume_unwind;
60 use std::panic::AssertUnwindSafe;
61 
62 use serde::de;
63 use serde::de::Error;
64 use serde::de::Visitor;
65 use serde::ser;
66 use serde::Deserialize;
67 use serde::Deserializer;
68 use serde::Serialize;
69 use serde::Serializer;
70 
71 use super::RawDescriptor;
72 use crate::descriptor::SafeDescriptor;
73 
74 thread_local! {
75     static DESCRIPTOR_DST: RefCell<Option<Vec<RawDescriptor>>> = Default::default();
76 }
77 
78 /// Initializes the thread local storage for descriptor serialization. Fails if it was already
79 /// initialized without an intervening `take_descriptor_dst` on this thread.
init_descriptor_dst() -> Result<(), &'static str>80 fn init_descriptor_dst() -> Result<(), &'static str> {
81     DESCRIPTOR_DST.with(|d| {
82         let mut descriptors = d.borrow_mut();
83         if descriptors.is_some() {
84             return Err(
85                 "attempt to initialize descriptor destination that was already initialized",
86             );
87         }
88         *descriptors = Some(Default::default());
89         Ok(())
90     })
91 }
92 
93 /// Takes the thread local storage for descriptor serialization. Fails if there wasn't a prior call
94 /// to `init_descriptor_dst` on this thread.
take_descriptor_dst() -> Result<Vec<RawDescriptor>, &'static str>95 fn take_descriptor_dst() -> Result<Vec<RawDescriptor>, &'static str> {
96     match DESCRIPTOR_DST.with(|d| d.replace(None)) {
97         Some(d) => Ok(d),
98         None => Err("attempt to take descriptor destination before it was initialized"),
99     }
100 }
101 
102 /// Pushes a descriptor on the thread local destination of descriptors, returning the index in which
103 /// the descriptor was pushed.
104 //
105 /// Returns Err if the thread local destination was not already initialized.
push_descriptor(rd: RawDescriptor) -> Result<usize, &'static str>106 fn push_descriptor(rd: RawDescriptor) -> Result<usize, &'static str> {
107     DESCRIPTOR_DST.with(|d| {
108         d.borrow_mut()
109             .as_mut()
110             .ok_or("attempt to serialize descriptor without descriptor destination")
111             .map(|descriptors| {
112                 let index = descriptors.len();
113                 descriptors.push(rd);
114                 index
115             })
116     })
117 }
118 
119 /// Serializes a descriptor for later retrieval in a parent `SerializeDescriptors` struct.
120 ///
121 /// If there is no parent `SerializeDescriptors` being serialized, this will return an error.
122 ///
123 /// For convenience, it is recommended to use the `with_raw_descriptor` module in a `#[serde(with =
124 /// "...")]` attribute which will make use of this function.
serialize_descriptor<S: Serializer>( rd: &RawDescriptor, se: S, ) -> std::result::Result<S::Ok, S::Error>125 pub fn serialize_descriptor<S: Serializer>(
126     rd: &RawDescriptor,
127     se: S,
128 ) -> std::result::Result<S::Ok, S::Error> {
129     let index = push_descriptor(*rd).map_err(ser::Error::custom)?;
130     se.serialize_u32(
131         index
132             .try_into()
133             .map_err(|_| ser::Error::custom("attempt to serialize too many descriptors at once"))?,
134     )
135 }
136 
137 /// Wrapper for a `Serialize` value which will capture any descriptors exported by the value when
138 /// given to an ordinary `Serializer`.
139 ///
140 /// This is the corresponding type to use for serialization before using
141 /// `deserialize_with_descriptors`.
142 ///
143 /// # Examples
144 ///
145 /// ```
146 /// use serde_json::to_string;
147 /// use base::platform::{FileSerdeWrapper, SerializeDescriptors};
148 /// use tempfile::tempfile;
149 ///
150 /// let tmp_f = tempfile().unwrap();
151 /// let data = FileSerdeWrapper(tmp_f);
152 /// let data_wrapper = SerializeDescriptors::new(&data);
153 ///
154 /// // Serializes `v` as normal...
155 /// let out_json = serde_json::to_string(&data_wrapper).expect("failed to serialize");
156 /// // If `serialize_descriptor` was called, we can capture the descriptors from here.
157 /// let out_descriptors = data_wrapper.into_descriptors();
158 /// ```
159 pub struct SerializeDescriptors<'a, T: Serialize>(&'a T, Cell<Vec<RawDescriptor>>);
160 
161 impl<'a, T: Serialize> SerializeDescriptors<'a, T> {
new(inner: &'a T) -> Self162     pub fn new(inner: &'a T) -> Self {
163         Self(inner, Default::default())
164     }
165 
into_descriptors(self) -> Vec<RawDescriptor>166     pub fn into_descriptors(self) -> Vec<RawDescriptor> {
167         self.1.into_inner()
168     }
169 }
170 
171 impl<'a, T: Serialize> Serialize for SerializeDescriptors<'a, T> {
serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error> where S: Serializer,172     fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
173     where
174         S: Serializer,
175     {
176         init_descriptor_dst().map_err(ser::Error::custom)?;
177 
178         // catch_unwind is used to ensure that init_descriptor_dst is always balanced with a call to
179         // take_descriptor_dst afterwards.
180         let res = catch_unwind(AssertUnwindSafe(|| self.0.serialize(serializer)));
181         self.1.set(take_descriptor_dst().unwrap());
182         match res {
183             Ok(r) => r,
184             Err(e) => resume_unwind(e),
185         }
186     }
187 }
188 
189 thread_local! {
190     static DESCRIPTOR_SRC: RefCell<Option<Vec<Option<SafeDescriptor>>>> = Default::default();
191 }
192 
193 /// Sets the thread local storage of descriptors for deserialization. Fails if this was already
194 /// called without a call to `take_descriptor_src` on this thread.
195 ///
196 /// This is given as a collection of `Option` so that unused descriptors can be returned.
set_descriptor_src(descriptors: Vec<Option<SafeDescriptor>>) -> Result<(), &'static str>197 fn set_descriptor_src(descriptors: Vec<Option<SafeDescriptor>>) -> Result<(), &'static str> {
198     DESCRIPTOR_SRC.with(|d| {
199         let mut src = d.borrow_mut();
200         if src.is_some() {
201             return Err("attempt to set descriptor source that was already set");
202         }
203         *src = Some(descriptors);
204         Ok(())
205     })
206 }
207 
208 /// Takes the thread local storage of descriptors for deserialization. Fails if the storage was
209 /// already taken or never set with `set_descriptor_src`.
210 ///
211 /// If deserialization was done, the descriptors will mostly come back as `None` unless some of them
212 /// were unused.
take_descriptor_src() -> Result<Vec<Option<SafeDescriptor>>, &'static str>213 fn take_descriptor_src() -> Result<Vec<Option<SafeDescriptor>>, &'static str> {
214     DESCRIPTOR_SRC.with(|d| {
215         d.replace(None)
216             .ok_or("attempt to take descriptor source which was never set")
217     })
218 }
219 
220 /// Takes a descriptor at the given index from the thread local source of descriptors.
221 //
222 /// Returns None if the thread local source was not already initialized.
take_descriptor(index: usize) -> Result<SafeDescriptor, &'static str>223 fn take_descriptor(index: usize) -> Result<SafeDescriptor, &'static str> {
224     DESCRIPTOR_SRC.with(|d| {
225         d.borrow_mut()
226             .as_mut()
227             .ok_or("attempt to deserialize descriptor without descriptor source")?
228             .get_mut(index)
229             .ok_or("attempt to deserialize out of bounds descriptor")?
230             .take()
231             .ok_or("attempt to deserialize descriptor that was already taken")
232     })
233 }
234 
235 /// Deserializes a descriptor provided via `deserialize_with_descriptors`.
236 ///
237 /// If `deserialize_with_descriptors` is not in the call chain, this will return an error.
238 ///
239 /// For convenience, it is recommended to use the `with_raw_descriptor` module in a `#[serde(with =
240 /// "...")]` attribute which will make use of this function.
deserialize_descriptor<'de, D>(de: D) -> std::result::Result<SafeDescriptor, D::Error> where D: Deserializer<'de>,241 pub fn deserialize_descriptor<'de, D>(de: D) -> std::result::Result<SafeDescriptor, D::Error>
242 where
243     D: Deserializer<'de>,
244 {
245     struct DescriptorVisitor;
246 
247     impl<'de> Visitor<'de> for DescriptorVisitor {
248         type Value = u32;
249 
250         fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
251             formatter.write_str("an integer which fits into a u32")
252         }
253 
254         fn visit_u8<E: de::Error>(self, value: u8) -> Result<Self::Value, E> {
255             Ok(value as _)
256         }
257 
258         fn visit_u16<E: de::Error>(self, value: u16) -> Result<Self::Value, E> {
259             Ok(value as _)
260         }
261 
262         fn visit_u32<E: de::Error>(self, value: u32) -> Result<Self::Value, E> {
263             Ok(value)
264         }
265 
266         fn visit_u64<E: de::Error>(self, value: u64) -> Result<Self::Value, E> {
267             value.try_into().map_err(E::custom)
268         }
269 
270         fn visit_u128<E: de::Error>(self, value: u128) -> Result<Self::Value, E> {
271             value.try_into().map_err(E::custom)
272         }
273 
274         fn visit_i8<E: de::Error>(self, value: i8) -> Result<Self::Value, E> {
275             value.try_into().map_err(E::custom)
276         }
277 
278         fn visit_i16<E: de::Error>(self, value: i16) -> Result<Self::Value, E> {
279             value.try_into().map_err(E::custom)
280         }
281 
282         fn visit_i32<E: de::Error>(self, value: i32) -> Result<Self::Value, E> {
283             value.try_into().map_err(E::custom)
284         }
285 
286         fn visit_i64<E: de::Error>(self, value: i64) -> Result<Self::Value, E> {
287             value.try_into().map_err(E::custom)
288         }
289 
290         fn visit_i128<E: de::Error>(self, value: i128) -> Result<Self::Value, E> {
291             value.try_into().map_err(E::custom)
292         }
293     }
294 
295     let index = de.deserialize_u32(DescriptorVisitor)? as usize;
296     take_descriptor(index).map_err(D::Error::custom)
297 }
298 
299 /// Allows the use of any serde deserializer within a closure while providing access to the a set of
300 /// descriptors for use in `deserialize_descriptor`.
301 ///
302 /// This is the corresponding call to use deserialize after using `SerializeDescriptors`.
303 ///
304 /// If `deserialize_with_descriptors` is called anywhere within the given closure, it return an
305 /// error.
deserialize_with_descriptors<F, T, E>( f: F, descriptors: &mut Vec<Option<SafeDescriptor>>, ) -> Result<T, E> where F: FnOnce() -> Result<T, E>, E: de::Error,306 pub fn deserialize_with_descriptors<F, T, E>(
307     f: F,
308     descriptors: &mut Vec<Option<SafeDescriptor>>,
309 ) -> Result<T, E>
310 where
311     F: FnOnce() -> Result<T, E>,
312     E: de::Error,
313 {
314     let swap_descriptors = std::mem::take(descriptors);
315     set_descriptor_src(swap_descriptors).map_err(E::custom)?;
316 
317     // catch_unwind is used to ensure that set_descriptor_src is always balanced with a call to
318     // take_descriptor_src afterwards.
319     let res = catch_unwind(AssertUnwindSafe(f));
320 
321     // unwrap is used because set_descriptor_src is always called before this, so it should never
322     // panic.
323     *descriptors = take_descriptor_src().unwrap();
324 
325     match res {
326         Ok(r) => r,
327         Err(e) => resume_unwind(e),
328     }
329 }
330 
331 /// Module that exports `serialize`/`deserialize` functions for use with `#[serde(with = "...")]`
332 /// attribute. It only works with fields with `RawDescriptor` type.
333 ///
334 /// # Examples
335 ///
336 /// ```
337 /// use serde::{Deserialize, Serialize};
338 /// use base::platform::RawDescriptor;
339 ///
340 /// #[derive(Serialize, Deserialize)]
341 /// struct RawContainer {
342 ///     #[serde(with = "base::platform::with_raw_descriptor")]
343 ///     rd: RawDescriptor,
344 /// }
345 /// ```
346 pub mod with_raw_descriptor {
347     use serde::Deserializer;
348 
349     use super::super::RawDescriptor;
350     pub use super::serialize_descriptor as serialize;
351     use crate::descriptor::IntoRawDescriptor;
352 
deserialize<'de, D>(de: D) -> std::result::Result<RawDescriptor, D::Error> where D: Deserializer<'de>,353     pub fn deserialize<'de, D>(de: D) -> std::result::Result<RawDescriptor, D::Error>
354     where
355         D: Deserializer<'de>,
356     {
357         super::deserialize_descriptor(de).map(IntoRawDescriptor::into_raw_descriptor)
358     }
359 }
360 
361 /// Module that exports `serialize`/`deserialize` functions for use with `#[serde(with = "...")]`
362 /// attribute.
363 ///
364 /// # Examples
365 ///
366 /// ```
367 /// use std::fs::File;
368 /// use serde::{Deserialize, Serialize};
369 /// use base::platform::RawDescriptor;
370 ///
371 /// #[derive(Serialize, Deserialize)]
372 /// struct FileContainer {
373 ///     #[serde(with = "base::platform::with_as_descriptor")]
374 ///     file: File,
375 /// }
376 /// ```
377 pub mod with_as_descriptor {
378     use serde::Deserializer;
379     use serde::Serializer;
380 
381     use crate::descriptor::AsRawDescriptor;
382     use crate::descriptor::FromRawDescriptor;
383     use crate::descriptor::IntoRawDescriptor;
384 
serialize<S: Serializer>( rd: &dyn AsRawDescriptor, se: S, ) -> std::result::Result<S::Ok, S::Error>385     pub fn serialize<S: Serializer>(
386         rd: &dyn AsRawDescriptor,
387         se: S,
388     ) -> std::result::Result<S::Ok, S::Error> {
389         super::serialize_descriptor(&rd.as_raw_descriptor(), se)
390     }
391 
deserialize<'de, D, T>(de: D) -> std::result::Result<T, D::Error> where D: Deserializer<'de>, T: FromRawDescriptor,392     pub fn deserialize<'de, D, T>(de: D) -> std::result::Result<T, D::Error>
393     where
394         D: Deserializer<'de>,
395         T: FromRawDescriptor,
396     {
397         super::deserialize_descriptor(de)
398             .map(IntoRawDescriptor::into_raw_descriptor)
399             .map(|rd| unsafe { T::from_raw_descriptor(rd) })
400     }
401 }
402 
403 /// A simple wrapper around `File` that implements `Serialize`/`Deserialize`, which is useful when
404 /// the `#[serde(with = "with_as_descriptor")]` trait is infeasible, such as for a field with type
405 /// `Option<File>`.
406 #[derive(Serialize, Deserialize)]
407 #[serde(transparent)]
408 pub struct FileSerdeWrapper(#[serde(with = "with_as_descriptor")] pub File);
409 
410 impl fmt::Debug for FileSerdeWrapper {
fmt(&self, f: &mut fmt::Formatter) -> fmt::Result411     fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
412         self.0.fmt(f)
413     }
414 }
415 
416 impl From<File> for FileSerdeWrapper {
from(file: File) -> Self417     fn from(file: File) -> Self {
418         FileSerdeWrapper(file)
419     }
420 }
421 
422 impl From<FileSerdeWrapper> for File {
from(f: FileSerdeWrapper) -> File423     fn from(f: FileSerdeWrapper) -> File {
424         f.0
425     }
426 }
427 
428 impl Deref for FileSerdeWrapper {
429     type Target = File;
deref(&self) -> &Self::Target430     fn deref(&self) -> &Self::Target {
431         &self.0
432     }
433 }
434 
435 impl DerefMut for FileSerdeWrapper {
deref_mut(&mut self) -> &mut Self::Target436     fn deref_mut(&mut self) -> &mut Self::Target {
437         &mut self.0
438     }
439 }
440 
441 #[cfg(test)]
442 mod tests {
443     use std::collections::HashMap;
444     use std::fs::File;
445     use std::mem::ManuallyDrop;
446 
447     use serde::de::DeserializeOwned;
448     use serde::Deserialize;
449     use serde::Serialize;
450     use tempfile::tempfile;
451 
452     use super::super::deserialize_with_descriptors;
453     use super::super::with_as_descriptor;
454     use super::super::with_raw_descriptor;
455     use super::super::AsRawDescriptor;
456     use super::super::FileSerdeWrapper;
457     use super::super::FromRawDescriptor;
458     use super::super::RawDescriptor;
459     use super::super::SafeDescriptor;
460     use super::super::SerializeDescriptors;
461 
deserialize<T: DeserializeOwned>(json: &str, descriptors: &[RawDescriptor]) -> T462     fn deserialize<T: DeserializeOwned>(json: &str, descriptors: &[RawDescriptor]) -> T {
463         let mut safe_descriptors = descriptors
464             .iter()
465             .map(|&v| Some(unsafe { SafeDescriptor::from_raw_descriptor(v) }))
466             .collect();
467 
468         let res =
469             deserialize_with_descriptors(|| serde_json::from_str(json), &mut safe_descriptors)
470                 .unwrap();
471 
472         assert!(safe_descriptors.iter().all(|v| v.is_none()));
473 
474         res
475     }
476 
477     #[test]
raw()478     fn raw() {
479         #[derive(Serialize, Deserialize, PartialEq, Debug)]
480         struct RawContainer {
481             #[serde(with = "with_raw_descriptor")]
482             rd: RawDescriptor,
483         }
484         // Specifically chosen to not overlap a real descriptor to avoid having to allocate any
485         // descriptors for this test.
486         let fake_rd = 5_123_457_i32;
487         let v = RawContainer {
488             rd: fake_rd as RawDescriptor,
489         };
490         let v_serialize = SerializeDescriptors::new(&v);
491         let json = serde_json::to_string(&v_serialize).unwrap();
492         let descriptors = v_serialize.into_descriptors();
493         let res = deserialize(&json, &descriptors);
494         assert_eq!(v, res);
495     }
496 
497     #[test]
file()498     fn file() {
499         #[derive(Serialize, Deserialize)]
500         struct FileContainer {
501             #[serde(with = "with_as_descriptor")]
502             file: File,
503         }
504 
505         let v = FileContainer {
506             file: tempfile().unwrap(),
507         };
508         let v_serialize = SerializeDescriptors::new(&v);
509         let json = serde_json::to_string(&v_serialize).unwrap();
510         let descriptors = v_serialize.into_descriptors();
511         let v = ManuallyDrop::new(v);
512         let res: FileContainer = deserialize(&json, &descriptors);
513         assert_eq!(v.file.as_raw_descriptor(), res.file.as_raw_descriptor());
514     }
515 
516     #[test]
option()517     fn option() {
518         #[derive(Serialize, Deserialize)]
519         struct TestOption {
520             a: Option<FileSerdeWrapper>,
521             b: Option<FileSerdeWrapper>,
522         }
523 
524         let v = TestOption {
525             a: None,
526             b: Some(tempfile().unwrap().into()),
527         };
528         let v_serialize = SerializeDescriptors::new(&v);
529         let json = serde_json::to_string(&v_serialize).unwrap();
530         let descriptors = v_serialize.into_descriptors();
531         let v = ManuallyDrop::new(v);
532         let res: TestOption = deserialize(&json, &descriptors);
533         assert!(res.a.is_none());
534         assert!(res.b.is_some());
535         assert_eq!(
536             v.b.as_ref().unwrap().as_raw_descriptor(),
537             res.b.unwrap().as_raw_descriptor()
538         );
539     }
540 
541     #[test]
map()542     fn map() {
543         let mut v: HashMap<String, FileSerdeWrapper> = HashMap::new();
544         v.insert("a".into(), tempfile().unwrap().into());
545         v.insert("b".into(), tempfile().unwrap().into());
546         v.insert("c".into(), tempfile().unwrap().into());
547         let v_serialize = SerializeDescriptors::new(&v);
548         let json = serde_json::to_string(&v_serialize).unwrap();
549         let descriptors = v_serialize.into_descriptors();
550         // Prevent the files in `v` from dropping while allowing the HashMap itself to drop. It is
551         // done this way to prevent a double close of the files (which should reside in `res`)
552         // without triggering the leak sanitizer on `v`'s HashMap heap memory.
553         let v: HashMap<_, _> = v
554             .into_iter()
555             .map(|(k, v)| (k, ManuallyDrop::new(v)))
556             .collect();
557         let res: HashMap<String, FileSerdeWrapper> = deserialize(&json, &descriptors);
558 
559         assert_eq!(v.len(), res.len());
560         for (k, v) in v.iter() {
561             assert_eq!(
562                 res.get(k).unwrap().as_raw_descriptor(),
563                 v.as_raw_descriptor()
564             );
565         }
566     }
567 }
568