• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 // Copyright 2020 The Chromium OS Authors. All rights reserved.
2 // Use of this source code is governed by a BSD-style license that can be
3 // found in the LICENSE file.
4 
5 //! Provides infrastructure for de/serializing descriptors embedded in Rust data structures.
6 //!
7 //! # Example
8 //!
9 //! ```
10 //! use serde_json::to_string;
11 //! use sys_util::{
12 //!     FileSerdeWrapper, FromRawDescriptor, SafeDescriptor, SerializeDescriptors,
13 //!     deserialize_with_descriptors,
14 //! };
15 //! use tempfile::tempfile;
16 //!
17 //! let tmp_f = tempfile().unwrap();
18 //!
19 //! // Uses a simple wrapper to serialize a File because we can't implement Serialize for File.
20 //! let data = FileSerdeWrapper(tmp_f);
21 //!
22 //! // Wraps Serialize types to collect side channel descriptors as Serialize is called.
23 //! let data_wrapper = SerializeDescriptors::new(&data);
24 //!
25 //! // Use the wrapper with any serializer to serialize data as normal, grabbing descriptors
26 //! // as the data structures are serialized by the serializer.
27 //! let out_json = serde_json::to_string(&data_wrapper).expect("failed to serialize");
28 //!
29 //! // If data_wrapper contains any side channel descriptor refs
30 //! // (it contains tmp_f in this case), we can retrieve the actual descriptors
31 //! // from the side channel using into_descriptors().
32 //! let out_descriptors = data_wrapper.into_descriptors();
33 //!
34 //! // When sending out_json over some transport, also send out_descriptors.
35 //!
36 //! // For this example, we aren't really transporting data across processes, but we do need to
37 //! // convert the descriptor type.
38 //! let mut safe_descriptors = out_descriptors
39 //!     .iter()
40 //!     .map(|&v| Some(unsafe { SafeDescriptor::from_raw_descriptor(v) }))
41 //!     .collect();
42 //! std::mem::forget(data); // Prevent double drop of tmp_f.
43 //!
44 //! // The deserialize_with_descriptors function is used to give the descriptor deserializers access
45 //! // to side channel descriptors.
46 //! let res: FileSerdeWrapper =
47 //!     deserialize_with_descriptors(|| serde_json::from_str(&out_json), &mut safe_descriptors)
48 //!        .expect("failed to deserialize");
49 //! ```
50 
51 use std::{
52     cell::{Cell, RefCell},
53     convert::TryInto,
54     fmt,
55     fs::File,
56     ops::{Deref, DerefMut},
57     panic::{catch_unwind, resume_unwind, AssertUnwindSafe},
58 };
59 
60 use serde::{
61     de::{
62         Error, Visitor, {self},
63     },
64     ser, Deserialize, Deserializer, Serialize, Serializer,
65 };
66 
67 use super::{RawDescriptor, SafeDescriptor};
68 
69 thread_local! {
70     static DESCRIPTOR_DST: RefCell<Option<Vec<RawDescriptor>>> = Default::default();
71 }
72 
73 /// Initializes the thread local storage for descriptor serialization. Fails if it was already
74 /// initialized without an intervening `take_descriptor_dst` on this thread.
init_descriptor_dst() -> Result<(), &'static str>75 fn init_descriptor_dst() -> Result<(), &'static str> {
76     DESCRIPTOR_DST.with(|d| {
77         let mut descriptors = d.borrow_mut();
78         if descriptors.is_some() {
79             return Err(
80                 "attempt to initialize descriptor destination that was already initialized",
81             );
82         }
83         *descriptors = Some(Default::default());
84         Ok(())
85     })
86 }
87 
88 /// Takes the thread local storage for descriptor serialization. Fails if there wasn't a prior call
89 /// to `init_descriptor_dst` on this thread.
take_descriptor_dst() -> Result<Vec<RawDescriptor>, &'static str>90 fn take_descriptor_dst() -> Result<Vec<RawDescriptor>, &'static str> {
91     match DESCRIPTOR_DST.with(|d| d.replace(None)) {
92         Some(d) => Ok(d),
93         None => Err("attempt to take descriptor destination before it was initialized"),
94     }
95 }
96 
97 /// Pushes a descriptor on the thread local destination of descriptors, returning the index in which
98 /// the descriptor was pushed.
99 //
100 /// Returns Err if the thread local destination was not already initialized.
push_descriptor(rd: RawDescriptor) -> Result<usize, &'static str>101 fn push_descriptor(rd: RawDescriptor) -> Result<usize, &'static str> {
102     DESCRIPTOR_DST.with(|d| {
103         d.borrow_mut()
104             .as_mut()
105             .ok_or("attempt to serialize descriptor without descriptor destination")
106             .map(|descriptors| {
107                 let index = descriptors.len();
108                 descriptors.push(rd);
109                 index
110             })
111     })
112 }
113 
114 /// Serializes a descriptor for later retrieval in a parent `SerializeDescriptors` struct.
115 ///
116 /// If there is no parent `SerializeDescriptors` being serialized, this will return an error.
117 ///
118 /// For convenience, it is recommended to use the `with_raw_descriptor` module in a `#[serde(with =
119 /// "...")]` attribute which will make use of this function.
serialize_descriptor<S: Serializer>( rd: &RawDescriptor, se: S, ) -> std::result::Result<S::Ok, S::Error>120 pub fn serialize_descriptor<S: Serializer>(
121     rd: &RawDescriptor,
122     se: S,
123 ) -> std::result::Result<S::Ok, S::Error> {
124     let index = push_descriptor(*rd).map_err(ser::Error::custom)?;
125     se.serialize_u32(
126         index
127             .try_into()
128             .map_err(|_| ser::Error::custom("attempt to serialize too many descriptors at once"))?,
129     )
130 }
131 
132 /// Wrapper for a `Serialize` value which will capture any descriptors exported by the value when
133 /// given to an ordinary `Serializer`.
134 ///
135 /// This is the corresponding type to use for serialization before using
136 /// `deserialize_with_descriptors`.
137 ///
138 /// # Examples
139 ///
140 /// ```
141 /// use serde_json::to_string;
142 /// use sys_util::{FileSerdeWrapper, SerializeDescriptors};
143 /// use tempfile::tempfile;
144 ///
145 /// let tmp_f = tempfile().unwrap();
146 /// let data = FileSerdeWrapper(tmp_f);
147 /// let data_wrapper = SerializeDescriptors::new(&data);
148 ///
149 /// // Serializes `v` as normal...
150 /// let out_json = serde_json::to_string(&data_wrapper).expect("failed to serialize");
151 /// // If `serialize_descriptor` was called, we can capture the descriptors from here.
152 /// let out_descriptors = data_wrapper.into_descriptors();
153 /// ```
154 pub struct SerializeDescriptors<'a, T: Serialize>(&'a T, Cell<Vec<RawDescriptor>>);
155 
156 impl<'a, T: Serialize> SerializeDescriptors<'a, T> {
new(inner: &'a T) -> Self157     pub fn new(inner: &'a T) -> Self {
158         Self(inner, Default::default())
159     }
160 
into_descriptors(self) -> Vec<RawDescriptor>161     pub fn into_descriptors(self) -> Vec<RawDescriptor> {
162         self.1.into_inner()
163     }
164 }
165 
166 impl<'a, T: Serialize> Serialize for SerializeDescriptors<'a, T> {
serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error> where S: Serializer,167     fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
168     where
169         S: Serializer,
170     {
171         init_descriptor_dst().map_err(ser::Error::custom)?;
172 
173         // catch_unwind is used to ensure that init_descriptor_dst is always balanced with a call to
174         // take_descriptor_dst afterwards.
175         let res = catch_unwind(AssertUnwindSafe(|| self.0.serialize(serializer)));
176         self.1.set(take_descriptor_dst().unwrap());
177         match res {
178             Ok(r) => r,
179             Err(e) => resume_unwind(e),
180         }
181     }
182 }
183 
184 thread_local! {
185     static DESCRIPTOR_SRC: RefCell<Option<Vec<Option<SafeDescriptor>>>> = Default::default();
186 }
187 
188 /// Sets the thread local storage of descriptors for deserialization. Fails if this was already
189 /// called without a call to `take_descriptor_src` on this thread.
190 ///
191 /// This is given as a collection of `Option` so that unused descriptors can be returned.
set_descriptor_src(descriptors: Vec<Option<SafeDescriptor>>) -> Result<(), &'static str>192 fn set_descriptor_src(descriptors: Vec<Option<SafeDescriptor>>) -> Result<(), &'static str> {
193     DESCRIPTOR_SRC.with(|d| {
194         let mut src = d.borrow_mut();
195         if src.is_some() {
196             return Err("attempt to set descriptor source that was already set");
197         }
198         *src = Some(descriptors);
199         Ok(())
200     })
201 }
202 
203 /// Takes the thread local storage of descriptors for deserialization. Fails if the storage was
204 /// already taken or never set with `set_descriptor_src`.
205 ///
206 /// If deserialization was done, the descriptors will mostly come back as `None` unless some of them
207 /// were unused.
take_descriptor_src() -> Result<Vec<Option<SafeDescriptor>>, &'static str>208 fn take_descriptor_src() -> Result<Vec<Option<SafeDescriptor>>, &'static str> {
209     DESCRIPTOR_SRC.with(|d| {
210         d.replace(None)
211             .ok_or("attempt to take descriptor source which was never set")
212     })
213 }
214 
215 /// Takes a descriptor at the given index from the thread local source of descriptors.
216 //
217 /// Returns None if the thread local source was not already initialized.
take_descriptor(index: usize) -> Result<SafeDescriptor, &'static str>218 fn take_descriptor(index: usize) -> Result<SafeDescriptor, &'static str> {
219     DESCRIPTOR_SRC.with(|d| {
220         d.borrow_mut()
221             .as_mut()
222             .ok_or("attempt to deserialize descriptor without descriptor source")?
223             .get_mut(index)
224             .ok_or("attempt to deserialize out of bounds descriptor")?
225             .take()
226             .ok_or("attempt to deserialize descriptor that was already taken")
227     })
228 }
229 
230 /// Deserializes a descriptor provided via `deserialize_with_descriptors`.
231 ///
232 /// If `deserialize_with_descriptors` is not in the call chain, this will return an error.
233 ///
234 /// For convenience, it is recommended to use the `with_raw_descriptor` module in a `#[serde(with =
235 /// "...")]` attribute which will make use of this function.
deserialize_descriptor<'de, D>(de: D) -> std::result::Result<SafeDescriptor, D::Error> where D: Deserializer<'de>,236 pub fn deserialize_descriptor<'de, D>(de: D) -> std::result::Result<SafeDescriptor, D::Error>
237 where
238     D: Deserializer<'de>,
239 {
240     struct DescriptorVisitor;
241 
242     impl<'de> Visitor<'de> for DescriptorVisitor {
243         type Value = u32;
244 
245         fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
246             formatter.write_str("an integer which fits into a u32")
247         }
248 
249         fn visit_u8<E: de::Error>(self, value: u8) -> Result<Self::Value, E> {
250             Ok(value as _)
251         }
252 
253         fn visit_u16<E: de::Error>(self, value: u16) -> Result<Self::Value, E> {
254             Ok(value as _)
255         }
256 
257         fn visit_u32<E: de::Error>(self, value: u32) -> Result<Self::Value, E> {
258             Ok(value)
259         }
260 
261         fn visit_u64<E: de::Error>(self, value: u64) -> Result<Self::Value, E> {
262             value.try_into().map_err(E::custom)
263         }
264 
265         fn visit_u128<E: de::Error>(self, value: u128) -> Result<Self::Value, E> {
266             value.try_into().map_err(E::custom)
267         }
268 
269         fn visit_i8<E: de::Error>(self, value: i8) -> Result<Self::Value, E> {
270             value.try_into().map_err(E::custom)
271         }
272 
273         fn visit_i16<E: de::Error>(self, value: i16) -> Result<Self::Value, E> {
274             value.try_into().map_err(E::custom)
275         }
276 
277         fn visit_i32<E: de::Error>(self, value: i32) -> Result<Self::Value, E> {
278             value.try_into().map_err(E::custom)
279         }
280 
281         fn visit_i64<E: de::Error>(self, value: i64) -> Result<Self::Value, E> {
282             value.try_into().map_err(E::custom)
283         }
284 
285         fn visit_i128<E: de::Error>(self, value: i128) -> Result<Self::Value, E> {
286             value.try_into().map_err(E::custom)
287         }
288     }
289 
290     let index = de.deserialize_u32(DescriptorVisitor)? as usize;
291     take_descriptor(index).map_err(D::Error::custom)
292 }
293 
294 /// Allows the use of any serde deserializer within a closure while providing access to the a set of
295 /// descriptors for use in `deserialize_descriptor`.
296 ///
297 /// This is the corresponding call to use deserialize after using `SerializeDescriptors`.
298 ///
299 /// If `deserialize_with_descriptors` is called anywhere within the given closure, it return an
300 /// error.
deserialize_with_descriptors<F, T, E>( f: F, descriptors: &mut Vec<Option<SafeDescriptor>>, ) -> Result<T, E> where F: FnOnce() -> Result<T, E>, E: de::Error,301 pub fn deserialize_with_descriptors<F, T, E>(
302     f: F,
303     descriptors: &mut Vec<Option<SafeDescriptor>>,
304 ) -> Result<T, E>
305 where
306     F: FnOnce() -> Result<T, E>,
307     E: de::Error,
308 {
309     let swap_descriptors = std::mem::take(descriptors);
310     set_descriptor_src(swap_descriptors).map_err(E::custom)?;
311 
312     // catch_unwind is used to ensure that set_descriptor_src is always balanced with a call to
313     // take_descriptor_src afterwards.
314     let res = catch_unwind(AssertUnwindSafe(f));
315 
316     // unwrap is used because set_descriptor_src is always called before this, so it should never
317     // panic.
318     *descriptors = take_descriptor_src().unwrap();
319 
320     match res {
321         Ok(r) => r,
322         Err(e) => resume_unwind(e),
323     }
324 }
325 
326 /// Module that exports `serialize`/`deserialize` functions for use with `#[serde(with = "...")]`
327 /// attribute. It only works with fields with `RawDescriptor` type.
328 ///
329 /// # Examples
330 ///
331 /// ```
332 /// use serde::{Deserialize, Serialize};
333 /// use sys_util::RawDescriptor;
334 ///
335 /// #[derive(Serialize, Deserialize)]
336 /// struct RawContainer {
337 ///     #[serde(with = "sys_util::with_raw_descriptor")]
338 ///     rd: RawDescriptor,
339 /// }
340 /// ```
341 pub mod with_raw_descriptor {
342     use super::super::{IntoRawDescriptor, RawDescriptor};
343     use serde::Deserializer;
344 
345     pub use super::serialize_descriptor as serialize;
346 
deserialize<'de, D>(de: D) -> std::result::Result<RawDescriptor, D::Error> where D: Deserializer<'de>,347     pub fn deserialize<'de, D>(de: D) -> std::result::Result<RawDescriptor, D::Error>
348     where
349         D: Deserializer<'de>,
350     {
351         super::deserialize_descriptor(de).map(IntoRawDescriptor::into_raw_descriptor)
352     }
353 }
354 
355 /// Module that exports `serialize`/`deserialize` functions for use with `#[serde(with = "...")]`
356 /// attribute.
357 ///
358 /// # Examples
359 ///
360 /// ```
361 /// use std::fs::File;
362 /// use serde::{Deserialize, Serialize};
363 /// use sys_util::RawDescriptor;
364 ///
365 /// #[derive(Serialize, Deserialize)]
366 /// struct FileContainer {
367 ///     #[serde(with = "sys_util::with_as_descriptor")]
368 ///     file: File,
369 /// }
370 /// ```
371 pub mod with_as_descriptor {
372     use super::super::{AsRawDescriptor, FromRawDescriptor, IntoRawDescriptor};
373     use serde::{Deserializer, Serializer};
374 
serialize<S: Serializer>( rd: &dyn AsRawDescriptor, se: S, ) -> std::result::Result<S::Ok, S::Error>375     pub fn serialize<S: Serializer>(
376         rd: &dyn AsRawDescriptor,
377         se: S,
378     ) -> std::result::Result<S::Ok, S::Error> {
379         super::serialize_descriptor(&rd.as_raw_descriptor(), se)
380     }
381 
deserialize<'de, D, T>(de: D) -> std::result::Result<T, D::Error> where D: Deserializer<'de>, T: FromRawDescriptor,382     pub fn deserialize<'de, D, T>(de: D) -> std::result::Result<T, D::Error>
383     where
384         D: Deserializer<'de>,
385         T: FromRawDescriptor,
386     {
387         super::deserialize_descriptor(de)
388             .map(IntoRawDescriptor::into_raw_descriptor)
389             .map(|rd| unsafe { T::from_raw_descriptor(rd) })
390     }
391 }
392 
393 /// A simple wrapper around `File` that implements `Serialize`/`Deserialize`, which is useful when
394 /// the `#[serde(with = "with_as_descriptor")]` trait is infeasible, such as for a field with type
395 /// `Option<File>`.
396 #[derive(Serialize, Deserialize)]
397 #[serde(transparent)]
398 pub struct FileSerdeWrapper(#[serde(with = "with_as_descriptor")] pub File);
399 
400 impl fmt::Debug for FileSerdeWrapper {
fmt(&self, f: &mut fmt::Formatter) -> fmt::Result401     fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
402         self.0.fmt(f)
403     }
404 }
405 
406 impl From<File> for FileSerdeWrapper {
from(file: File) -> Self407     fn from(file: File) -> Self {
408         FileSerdeWrapper(file)
409     }
410 }
411 
412 impl From<FileSerdeWrapper> for File {
from(f: FileSerdeWrapper) -> File413     fn from(f: FileSerdeWrapper) -> File {
414         f.0
415     }
416 }
417 
418 impl Deref for FileSerdeWrapper {
419     type Target = File;
deref(&self) -> &Self::Target420     fn deref(&self) -> &Self::Target {
421         &self.0
422     }
423 }
424 
425 impl DerefMut for FileSerdeWrapper {
deref_mut(&mut self) -> &mut Self::Target426     fn deref_mut(&mut self) -> &mut Self::Target {
427         &mut self.0
428     }
429 }
430 
#[cfg(test)]
mod tests {
    use super::super::{
        deserialize_with_descriptors, with_as_descriptor, with_raw_descriptor, FileSerdeWrapper,
        FromRawDescriptor, RawDescriptor, SafeDescriptor, SerializeDescriptors,
    };

    use std::{collections::HashMap, fs::File, mem::ManuallyDrop, os::unix::io::AsRawFd};

    use serde::{de::DeserializeOwned, Deserialize, Serialize};
    use tempfile::tempfile;

    /// Serializes `value` to JSON inside `SerializeDescriptors`, returning the JSON text together
    /// with the raw descriptors captured on the side channel.
    fn to_json_with_descriptors<T: Serialize>(value: &T) -> (String, Vec<RawDescriptor>) {
        let wrapper = SerializeDescriptors::new(value);
        let json = serde_json::to_string(&wrapper).unwrap();
        (json, wrapper.into_descriptors())
    }

    /// Deserializes `json` with `descriptors` as the side channel and asserts that every
    /// descriptor was consumed by the deserializer.
    fn from_json_with_descriptors<T: DeserializeOwned>(
        json: &str,
        descriptors: &[RawDescriptor],
    ) -> T {
        let mut side_channel: Vec<Option<SafeDescriptor>> = descriptors
            .iter()
            // SAFETY (test only): each descriptor is either fake and handed back out as a raw
            // value, or belongs to a value the test keeps in `ManuallyDrop`, so it is never
            // closed twice.
            .map(|&rd| Some(unsafe { SafeDescriptor::from_raw_descriptor(rd) }))
            .collect();

        let value =
            deserialize_with_descriptors(|| serde_json::from_str(json), &mut side_channel)
                .unwrap();

        assert!(side_channel.iter().all(Option::is_none));
        value
    }

    #[test]
    fn raw() {
        #[derive(Serialize, Deserialize, PartialEq, Debug)]
        struct RawContainer {
            #[serde(with = "with_raw_descriptor")]
            rd: RawDescriptor,
        }

        // Specifically chosen to not overlap a real descriptor to avoid having to allocate any
        // descriptors for this test.
        let fake_rd = 5_123_457_i32;
        let original = RawContainer { rd: fake_rd };
        let (json, descriptors) = to_json_with_descriptors(&original);
        let round_tripped: RawContainer = from_json_with_descriptors(&json, &descriptors);
        assert_eq!(original, round_tripped);
    }

    #[test]
    fn file() {
        #[derive(Serialize, Deserialize)]
        struct FileContainer {
            #[serde(with = "with_as_descriptor")]
            file: File,
        }

        let original = FileContainer {
            file: tempfile().unwrap(),
        };
        let (json, descriptors) = to_json_with_descriptors(&original);
        // The deserialized copy adopts the descriptor, so the original must not close it on drop.
        let original = ManuallyDrop::new(original);
        let round_tripped: FileContainer = from_json_with_descriptors(&json, &descriptors);
        assert_eq!(original.file.as_raw_fd(), round_tripped.file.as_raw_fd());
    }

    #[test]
    fn option() {
        #[derive(Serialize, Deserialize)]
        struct TestOption {
            a: Option<FileSerdeWrapper>,
            b: Option<FileSerdeWrapper>,
        }

        let original = TestOption {
            a: None,
            b: Some(tempfile().unwrap().into()),
        };
        let (json, descriptors) = to_json_with_descriptors(&original);
        let original = ManuallyDrop::new(original);
        let round_tripped: TestOption = from_json_with_descriptors(&json, &descriptors);

        assert!(round_tripped.a.is_none());
        let original_fd = original.b.as_ref().unwrap().as_raw_fd();
        let round_tripped_fd = round_tripped.b.as_ref().map(|f| f.as_raw_fd());
        assert_eq!(Some(original_fd), round_tripped_fd);
    }

    #[test]
    fn map() {
        let mut original: HashMap<String, FileSerdeWrapper> = HashMap::new();
        for key in ["a", "b", "c"].iter() {
            original.insert(String::from(*key), tempfile().unwrap().into());
        }
        let (json, descriptors) = to_json_with_descriptors(&original);
        // Prevent the files in `original` from dropping while allowing the HashMap itself to
        // drop. This avoids a double close of the files (which should reside in `round_tripped`)
        // without triggering the leak sanitizer on the HashMap's heap memory.
        let original: HashMap<String, ManuallyDrop<FileSerdeWrapper>> = original
            .into_iter()
            .map(|(key, file)| (key, ManuallyDrop::new(file)))
            .collect();
        let round_tripped: HashMap<String, FileSerdeWrapper> =
            from_json_with_descriptors(&json, &descriptors);

        assert_eq!(original.len(), round_tripped.len());
        for (key, file) in &original {
            assert_eq!(round_tripped[key].as_raw_fd(), file.as_raw_fd());
        }
    }
}
544