• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 use proc_macro2::Span;
2 use quote::{ToTokens, TokenStreamExt};
3 use serde::{Deserialize, Serialize};
4 use syn::Token;
5 
6 use std::fmt;
7 use std::ops::ControlFlow;
8 use std::str::FromStr;
9 
10 use super::{
11     Attrs, Docs, Enum, Ident, Lifetime, LifetimeEnv, LifetimeTransitivity, Method, NamedLifetime,
12     OpaqueType, Path, RustLink, Struct, Trait,
13 };
14 use crate::Env;
15 
16 /// A type declared inside a Diplomat-annotated module.
17 #[derive(Clone, Serialize, Debug, Hash, PartialEq, Eq)]
18 #[non_exhaustive]
19 pub enum CustomType {
20     /// A non-opaque struct whose fields will be visible across the FFI boundary.
21     Struct(Struct),
22     /// A type annotated with [`diplomat::opaque`] whose fields are not visible.
23     Opaque(OpaqueType),
24     /// A fieldless enum.
25     Enum(Enum),
26 }
27 
28 impl CustomType {
29     /// Get the name of the custom type, which is unique within a module.
name(&self) -> &Ident30     pub fn name(&self) -> &Ident {
31         match self {
32             CustomType::Struct(strct) => &strct.name,
33             CustomType::Opaque(strct) => &strct.name,
34             CustomType::Enum(enm) => &enm.name,
35         }
36     }
37 
38     /// Get the methods declared in impls of the custom type.
methods(&self) -> &Vec<Method>39     pub fn methods(&self) -> &Vec<Method> {
40         match self {
41             CustomType::Struct(strct) => &strct.methods,
42             CustomType::Opaque(strct) => &strct.methods,
43             CustomType::Enum(enm) => &enm.methods,
44         }
45     }
46 
attrs(&self) -> &Attrs47     pub fn attrs(&self) -> &Attrs {
48         match self {
49             CustomType::Struct(strct) => &strct.attrs,
50             CustomType::Opaque(strct) => &strct.attrs,
51             CustomType::Enum(enm) => &enm.attrs,
52         }
53     }
54 
55     /// Get the doc lines of the custom type.
docs(&self) -> &Docs56     pub fn docs(&self) -> &Docs {
57         match self {
58             CustomType::Struct(strct) => &strct.docs,
59             CustomType::Opaque(strct) => &strct.docs,
60             CustomType::Enum(enm) => &enm.docs,
61         }
62     }
63 
64     /// Get all rust links on this type and its methods
all_rust_links(&self) -> impl Iterator<Item = &RustLink> + '_65     pub fn all_rust_links(&self) -> impl Iterator<Item = &RustLink> + '_ {
66         [self.docs()]
67             .into_iter()
68             .chain(self.methods().iter().map(|m| m.docs()))
69             .flat_map(|d| d.rust_links().iter())
70     }
71 
self_path(&self, in_path: &Path) -> Path72     pub fn self_path(&self, in_path: &Path) -> Path {
73         in_path.sub_path(self.name().clone())
74     }
75 
76     /// Get the lifetimes of the custom type.
lifetimes(&self) -> Option<&LifetimeEnv>77     pub fn lifetimes(&self) -> Option<&LifetimeEnv> {
78         match self {
79             CustomType::Struct(strct) => Some(&strct.lifetimes),
80             CustomType::Opaque(strct) => Some(&strct.lifetimes),
81             CustomType::Enum(_) => None,
82         }
83     }
84 }
85 
86 /// A symbol declared in a module, which can either be a pointer to another path,
87 /// or a custom type defined directly inside that module
88 #[derive(Clone, Serialize, Debug)]
89 #[non_exhaustive]
90 pub enum ModSymbol {
91     /// A symbol that is a pointer to another path.
92     Alias(Path),
93     /// A symbol that is a submodule.
94     SubModule(Ident),
95     /// A symbol that is a custom type.
96     CustomType(CustomType),
97     /// A trait
98     Trait(Trait),
99 }
100 
101 /// A named type that is just a path, e.g. `std::borrow::Cow<'a, T>`.
102 #[derive(Clone, PartialEq, Eq, Hash, Serialize, Deserialize, Debug)]
103 #[non_exhaustive]
104 pub struct PathType {
105     pub path: Path,
106     pub lifetimes: Vec<Lifetime>,
107 }
108 
109 impl PathType {
to_syn(&self) -> syn::TypePath110     pub fn to_syn(&self) -> syn::TypePath {
111         let mut path = self.path.to_syn();
112 
113         if !self.lifetimes.is_empty() {
114             if let Some(seg) = path.segments.last_mut() {
115                 let lifetimes = &self.lifetimes;
116                 seg.arguments =
117                     syn::PathArguments::AngleBracketed(syn::parse_quote! { <#(#lifetimes),*> });
118             }
119         }
120 
121         syn::TypePath { qself: None, path }
122     }
123 
new(path: Path) -> Self124     pub fn new(path: Path) -> Self {
125         Self {
126             path,
127             lifetimes: vec![],
128         }
129     }
130 
131     /// Get the `Self` type from a struct declaration.
132     ///
133     /// Consider the following struct declaration:
134     /// ```
135     /// struct RefList<'a> {
136     ///     data: &'a i32,
137     ///     next: Option<Box<Self>>,
138     /// }
139     /// ```
140     /// When determining what type `Self` is in the `next` field, we would have to call
141     /// this method on the `syn::ItemStruct` that represents this struct declaration.
142     /// This method would then return a `PathType` representing `RefList<'a>`, so we
143     /// know that's what `Self` should refer to.
144     ///
145     /// The reason this function exists though is so when we convert the fields' types
146     /// to `PathType`s, we don't panic. We don't actually need to write the struct's
147     /// field types expanded in the macro, so this function is more for correctness,
extract_self_type(strct: &syn::ItemStruct) -> Self148     pub fn extract_self_type(strct: &syn::ItemStruct) -> Self {
149         let self_name = (&strct.ident).into();
150 
151         PathType {
152             path: Path {
153                 elements: vec![self_name],
154             },
155             lifetimes: strct
156                 .generics
157                 .lifetimes()
158                 .map(|lt_def| (&lt_def.lifetime).into())
159                 .collect(),
160         }
161     }
162 
163     /// If this is a [`TypeName::Named`], grab the [`CustomType`] it points to from
164     /// the `env`, which contains all [`CustomType`]s across all FFI modules.
165     ///
166     /// Also returns the path the CustomType is in (useful for resolving fields)
resolve_with_path<'a>(&self, in_path: &Path, env: &'a Env) -> (Path, &'a CustomType)167     pub fn resolve_with_path<'a>(&self, in_path: &Path, env: &'a Env) -> (Path, &'a CustomType) {
168         let local_path = &self.path;
169         let mut cur_path = in_path.clone();
170         for (i, elem) in local_path.elements.iter().enumerate() {
171             match elem.as_str() {
172                 "crate" => {
173                     // TODO(#34): get the name of enclosing crate from env when we support multiple crates
174                     cur_path = Path::empty()
175                 }
176 
177                 "super" => cur_path = cur_path.get_super(),
178 
179                 o => match env.get(&cur_path, o) {
180                     Some(ModSymbol::Alias(p)) => {
181                         let mut remaining_elements: Vec<Ident> =
182                             local_path.elements.iter().skip(i + 1).cloned().collect();
183                         let mut new_path = p.elements.clone();
184                         new_path.append(&mut remaining_elements);
185                         return PathType::new(Path { elements: new_path })
186                             .resolve_with_path(&cur_path.clone(), env);
187                     }
188                     Some(ModSymbol::SubModule(name)) => {
189                         cur_path.elements.push(name.clone());
190                     }
191                     Some(ModSymbol::CustomType(t)) => {
192                         if i == local_path.elements.len() - 1 {
193                             return (cur_path, t);
194                         } else {
195                             panic!(
196                                 "Unexpected custom type when resolving symbol {} in {}",
197                                 o,
198                                 cur_path.elements.join("::")
199                             )
200                         }
201                     }
202                     Some(ModSymbol::Trait(trt)) => {
203                         panic!("Found trait {} but expected a type", trt.name);
204                     }
205                     None => panic!(
206                         "Could not resolve symbol {} in {}",
207                         o,
208                         cur_path.elements.join("::")
209                     ),
210                 },
211             }
212         }
213 
214         panic!(
215             "Path {} does not point to a custom type",
216             in_path.elements.join("::")
217         )
218     }
219 
220     /// If this is a [`TypeName::Named`], grab the [`CustomType`] it points to from
221     /// the `env`, which contains all [`CustomType`]s across all FFI modules.
222     ///
223     /// If you need to resolve struct fields later, call [`Self::resolve_with_path()`] instead
224     /// to get the path to resolve the fields in.
resolve<'a>(&self, in_path: &Path, env: &'a Env) -> &'a CustomType225     pub fn resolve<'a>(&self, in_path: &Path, env: &'a Env) -> &'a CustomType {
226         self.resolve_with_path(in_path, env).1
227     }
228 
trait_to_syn(&self) -> syn::TraitBound229     pub fn trait_to_syn(&self) -> syn::TraitBound {
230         let mut path = self.path.to_syn();
231 
232         if !self.lifetimes.is_empty() {
233             if let Some(seg) = path.segments.last_mut() {
234                 let lifetimes = &self.lifetimes;
235                 seg.arguments =
236                     syn::PathArguments::AngleBracketed(syn::parse_quote! { <#(#lifetimes),*> });
237             }
238         }
239         syn::TraitBound {
240             paren_token: None,
241             modifier: syn::TraitBoundModifier::None,
242             lifetimes: None, // todo this is an assumption
243             path,
244         }
245     }
246 
resolve_trait_with_path<'a>(&self, in_path: &Path, env: &'a Env) -> (Path, Trait)247     pub fn resolve_trait_with_path<'a>(&self, in_path: &Path, env: &'a Env) -> (Path, Trait) {
248         let local_path = &self.path;
249         let cur_path = in_path.clone();
250         for (i, elem) in local_path.elements.iter().enumerate() {
251             if let Some(ModSymbol::Trait(trt)) = env.get(&cur_path, elem.as_str()) {
252                 if i == local_path.elements.len() - 1 {
253                     return (cur_path, trt.clone());
254                 } else {
255                     panic!(
256                         "Unexpected custom trait when resolving symbol {} in {}",
257                         trt.name,
258                         cur_path.elements.join("::")
259                     )
260                 }
261             }
262         }
263 
264         panic!(
265             "Path {} does not point to a custom trait",
266             in_path.elements.join("::")
267         )
268     }
269 
270     /// If this is a [`TypeName::Named`], grab the [`CustomType`] it points to from
271     /// the `env`, which contains all [`CustomType`]s across all FFI modules.
272     ///
273     /// If you need to resolve struct fields later, call [`Self::resolve_with_path()`] instead
274     /// to get the path to resolve the fields in.
resolve_trait<'a>(&self, in_path: &Path, env: &'a Env) -> Trait275     pub fn resolve_trait<'a>(&self, in_path: &Path, env: &'a Env) -> Trait {
276         self.resolve_trait_with_path(in_path, env).1
277     }
278 }
279 
280 impl From<&syn::TypePath> for PathType {
from(other: &syn::TypePath) -> Self281     fn from(other: &syn::TypePath) -> Self {
282         let lifetimes = other
283             .path
284             .segments
285             .last()
286             .and_then(|last| {
287                 if let syn::PathArguments::AngleBracketed(angle_generics) = &last.arguments {
288                     Some(
289                         angle_generics
290                             .args
291                             .iter()
292                             .map(|generic_arg| match generic_arg {
293                                 syn::GenericArgument::Lifetime(lifetime) => lifetime.into(),
294                                 _ => panic!("generic type arguments are unsupported {other:?}"),
295                             })
296                             .collect(),
297                     )
298                 } else {
299                     None
300                 }
301             })
302             .unwrap_or_default();
303 
304         Self {
305             path: Path::from_syn(&other.path),
306             lifetimes,
307         }
308     }
309 }
310 
311 impl From<&syn::TraitBound> for PathType {
from(other: &syn::TraitBound) -> Self312     fn from(other: &syn::TraitBound) -> Self {
313         let lifetimes = other
314             .path
315             .segments
316             .last()
317             .and_then(|last| {
318                 if let syn::PathArguments::AngleBracketed(angle_generics) = &last.arguments {
319                     Some(
320                         angle_generics
321                             .args
322                             .iter()
323                             .map(|generic_arg| match generic_arg {
324                                 syn::GenericArgument::Lifetime(lifetime) => lifetime.into(),
325                                 _ => panic!("generic type arguments are unsupported {other:?}"),
326                             })
327                             .collect(),
328                     )
329                 } else {
330                     None
331                 }
332             })
333             .unwrap_or_default();
334 
335         Self {
336             path: Path::from_syn(&other.path),
337             lifetimes,
338         }
339     }
340 }
341 
342 impl From<Path> for PathType {
from(other: Path) -> Self343     fn from(other: Path) -> Self {
344         PathType::new(other)
345     }
346 }
347 
348 #[derive(Copy, Clone, PartialEq, Eq, Hash, Serialize, Deserialize, Debug)]
349 #[allow(clippy::exhaustive_enums)] // there are only two kinds of mutability we care about
350 pub enum Mutability {
351     Mutable,
352     Immutable,
353 }
354 
355 impl Mutability {
to_syn(&self) -> Option<Token![mut]>356     pub fn to_syn(&self) -> Option<Token![mut]> {
357         match self {
358             Mutability::Mutable => Some(syn::token::Mut(Span::call_site())),
359             Mutability::Immutable => None,
360         }
361     }
362 
from_syn(t: &Option<Token![mut]>) -> Self363     pub fn from_syn(t: &Option<Token![mut]>) -> Self {
364         match t {
365             Some(_) => Mutability::Mutable,
366             None => Mutability::Immutable,
367         }
368     }
369 
370     /// Returns `true` if `&self` is the mutable variant, otherwise `false`.
is_mutable(&self) -> bool371     pub fn is_mutable(&self) -> bool {
372         matches!(self, Mutability::Mutable)
373     }
374 
375     /// Returns `true` if `&self` is the immutable variant, otherwise `false`.
is_immutable(&self) -> bool376     pub fn is_immutable(&self) -> bool {
377         matches!(self, Mutability::Immutable)
378     }
379 
380     /// Shorthand ternary operator for choosing a value based on whether
381     /// a `Mutability` is mutable or immutable.
382     ///
383     /// The following pattern (with very slight variations) shows up often in code gen:
384     /// ```ignore
385     /// if mutability.is_mutable() {
386     ///     ""
387     /// } else {
388     ///     "const "
389     /// }
390     /// ```
391     /// This is particularly annoying in `write!(...)` statements, where `cargo fmt`
392     /// expands it to take up 5 lines.
393     ///
394     /// This method offers a 1-line alternative:
395     /// ```ignore
396     /// mutability.if_mut_else("", "const ")
397     /// ```
398     /// For cases where lazy evaluation is desired, consider using a conditional
399     /// or a `match` statement.
if_mut_else<T>(&self, if_mut: T, if_immut: T) -> T400     pub fn if_mut_else<T>(&self, if_mut: T, if_immut: T) -> T {
401         match self {
402             Mutability::Mutable => if_mut,
403             Mutability::Immutable => if_immut,
404         }
405     }
406 }
407 
408 /// For types like `Result`/`DiplomatResult`, `&[T]`/`DiplomatSlice<T>` which can be
409 /// specified using (non-ffi-safe) Rust stdlib types, or FFI-safe `repr(C)` types from
410 /// `diplomat_runtime`, this tracks which of the two were used.
411 #[derive(Copy, Clone, PartialEq, Eq, Hash, Serialize, Deserialize, Debug)]
412 #[allow(clippy::exhaustive_enums)] // This can only have two values
413 pub enum StdlibOrDiplomat {
414     Stdlib,
415     Diplomat,
416 }
417 
418 /// A local type reference, such as the type of a field, parameter, or return value.
419 /// Unlike [`CustomType`], which represents a type declaration, [`TypeName`]s can compose
420 /// types through references and boxing, and can also capture unresolved paths.
421 #[derive(Clone, PartialEq, Eq, Hash, Serialize, Deserialize, Debug)]
422 #[non_exhaustive]
423 pub enum TypeName {
424     /// A built-in Rust scalar primitive.
425     Primitive(PrimitiveType),
426     /// An unresolved path to a custom type, which can be resolved after all types
427     /// are collected with [`TypeName::resolve()`].
428     Named(PathType),
429     /// An optionally mutable reference to another type.
430     Reference(Lifetime, Mutability, Box<TypeName>),
431     /// A `Box<T>` type.
432     Box(Box<TypeName>),
433     /// An `Option<T>` or DiplomatOption type.
434     Option(Box<TypeName>, StdlibOrDiplomat),
435     /// A `Result<T, E>` or `diplomat_runtime::DiplomatResult` type.
436     Result(Box<TypeName>, Box<TypeName>, StdlibOrDiplomat),
437     Write,
438     /// A `&DiplomatStr` or `Box<DiplomatStr>` type.
439     /// Owned strings don't have a lifetime.
440     ///
441     /// If StdlibOrDiplomat::Stdlib, it's specified using Rust pointer types (&T, Box<T>),
442     /// if StdlibOrDiplomat::Diplomat, it's specified using DiplomatStrSlice, etc
443     StrReference(Option<Lifetime>, StringEncoding, StdlibOrDiplomat),
444     /// A `&[T]` or `Box<[T]>` type, where `T` is a primitive.
445     /// Owned slices don't have a lifetime or mutability.
446     ///
447     /// If StdlibOrDiplomat::Stdlib, it's specified using Rust pointer types (&T, Box<T>),
448     /// if StdlibOrDiplomat::Diplomat, it's specified using DiplomatSlice/DiplomatOwnedSlice/DiplomatSliceMut
449     PrimitiveSlice(
450         Option<(Lifetime, Mutability)>,
451         PrimitiveType,
452         StdlibOrDiplomat,
453     ),
454     /// `&[DiplomatStrSlice]`, etc. Equivalent to `&[&str]`
455     ///
456     /// If StdlibOrDiplomat::Stdlib, it's specified as `&[&DiplomatFoo]`, if StdlibOrDiplomat::Diplomat it's specified
457     /// as `DiplomatSlice<&DiplomatFoo>`
458     StrSlice(StringEncoding, StdlibOrDiplomat),
459     /// The `()` type.
460     Unit,
461     /// The `Self` type.
462     SelfType(PathType),
463     /// std::cmp::Ordering or core::cmp::Ordering
464     ///
465     /// The path must be present! Ordering will be parsed as an AST type!
466     Ordering,
467     Function(Vec<Box<TypeName>>, Box<TypeName>),
468     ImplTrait(PathType),
469 }
470 
471 #[derive(Clone, PartialEq, Eq, Hash, Serialize, Deserialize, Debug, Copy)]
472 #[non_exhaustive]
473 pub enum StringEncoding {
474     UnvalidatedUtf8,
475     UnvalidatedUtf16,
476     /// The caller guarantees that they're passing valid UTF-8, under penalty of UB
477     Utf8,
478 }
479 
480 impl StringEncoding {
481     /// Get the diplomat slice type when specified using diplomat_runtime types
get_diplomat_slice_type(self, lt: &Option<Lifetime>) -> syn::Type482     pub fn get_diplomat_slice_type(self, lt: &Option<Lifetime>) -> syn::Type {
483         if let Some(ref lt) = *lt {
484             let lt = LifetimeGenericsListDisplay(lt);
485 
486             match self {
487                 Self::UnvalidatedUtf8 => {
488                     syn::parse_quote_spanned!(Span::call_site() => diplomat_runtime::DiplomatStrSlice #lt)
489                 }
490                 Self::UnvalidatedUtf16 => {
491                     syn::parse_quote_spanned!(Span::call_site() => diplomat_runtime::DiplomatStr16Slice #lt)
492                 }
493                 Self::Utf8 => {
494                     syn::parse_quote_spanned!(Span::call_site() => diplomat_runtime::DiplomatUtf8StrSlice #lt)
495                 }
496             }
497         } else {
498             match self {
499                 Self::UnvalidatedUtf8 => {
500                     syn::parse_quote_spanned!(Span::call_site() => diplomat_runtime::DiplomatOwnedStrSlice)
501                 }
502                 Self::UnvalidatedUtf16 => {
503                     syn::parse_quote_spanned!(Span::call_site() => diplomat_runtime::DiplomatOwnedStr16Slice)
504                 }
505                 Self::Utf8 => {
506                     syn::parse_quote_spanned!(Span::call_site() => diplomat_runtime::DiplomatOwnedUTF8StrSlice)
507                 }
508             }
509         }
510     }
511 
get_diplomat_slice_type_str(self) -> &'static str512     fn get_diplomat_slice_type_str(self) -> &'static str {
513         match self {
514             StringEncoding::Utf8 => "str",
515             StringEncoding::UnvalidatedUtf8 => "DiplomatStr",
516             StringEncoding::UnvalidatedUtf16 => "DiplomatStr16",
517         }
518     }
519     /// Get slice type when specified using rust stdlib types
get_stdlib_slice_type(self, lt: &Option<Lifetime>) -> syn::Type520     pub fn get_stdlib_slice_type(self, lt: &Option<Lifetime>) -> syn::Type {
521         let inner = match self {
522             Self::UnvalidatedUtf8 => quote::quote!(DiplomatStr),
523             Self::UnvalidatedUtf16 => quote::quote!(DiplomatStr16),
524             Self::Utf8 => quote::quote!(str),
525         };
526         if let Some(ref lt) = *lt {
527             let lt = ReferenceDisplay(lt, &Mutability::Immutable);
528 
529             syn::parse_quote_spanned!(Span::call_site() => #lt #inner)
530         } else {
531             syn::parse_quote_spanned!(Span::call_site() => Box<#inner>)
532         }
533     }
get_stdlib_slice_type_str(self) -> &'static str534     pub fn get_stdlib_slice_type_str(self) -> &'static str {
535         match self {
536             StringEncoding::Utf8 => "DiplomatUtf8Str",
537             StringEncoding::UnvalidatedUtf8 => "DiplomatStrSlice",
538             StringEncoding::UnvalidatedUtf16 => "DiplomatStr16Slice",
539         }
540     }
541 }
542 
get_lifetime_from_syn_path(p: &syn::TypePath) -> Lifetime543 fn get_lifetime_from_syn_path(p: &syn::TypePath) -> Lifetime {
544     if let syn::PathArguments::AngleBracketed(ref generics) =
545         p.path.segments[p.path.segments.len() - 1].arguments
546     {
547         if let Some(syn::GenericArgument::Lifetime(lt)) = generics.args.first() {
548             return Lifetime::from(lt);
549         }
550     }
551     Lifetime::Anonymous
552 }
553 
get_ty_from_syn_path(p: &syn::TypePath) -> Option<&syn::Type>554 fn get_ty_from_syn_path(p: &syn::TypePath) -> Option<&syn::Type> {
555     if let syn::PathArguments::AngleBracketed(ref generics) =
556         p.path.segments[p.path.segments.len() - 1].arguments
557     {
558         for gen in generics.args.iter() {
559             if let syn::GenericArgument::Type(ref ty) = gen {
560                 return Some(ty);
561             }
562         }
563     }
564     None
565 }
566 
567 impl TypeName {
568     /// Is this type safe to be passed across the FFI boundary?
569     ///
570     /// This also marks DiplomatOption<&T> as FFI-unsafe: these are technically safe from an ABI standpoint
571     /// however Diplomat always expects these to be equivalent to a nullable pointer, so Option<&T> is required.
is_ffi_safe(&self) -> bool572     pub fn is_ffi_safe(&self) -> bool {
573         match self {
574             TypeName::Primitive(..) | TypeName::Named(_) | TypeName::SelfType(_) | TypeName::Reference(..) |
575             TypeName::Box(..) |
576             // can only be passed across the FFI boundary; callbacks and traits are input-only
577             TypeName::Function(..) | TypeName::ImplTrait(..) |
578             // These are specified using FFI-safe diplomat_runtime types
579             TypeName::StrReference(.., StdlibOrDiplomat::Diplomat) | TypeName::StrSlice(.., StdlibOrDiplomat::Diplomat) |TypeName::PrimitiveSlice(.., StdlibOrDiplomat::Diplomat) => true,
580             // These are special anyway and shouldn't show up in structs
581             TypeName::Unit | TypeName::Write | TypeName::Result(..) |
582             // This is basically only useful in return types
583             TypeName::Ordering |
584             // These are specified using Rust stdlib types and not safe across FFI
585             TypeName::StrReference(.., StdlibOrDiplomat::Stdlib) | TypeName::StrSlice(.., StdlibOrDiplomat::Stdlib) | TypeName::PrimitiveSlice(.., StdlibOrDiplomat::Stdlib)  => false,
586             TypeName::Option(inner, stdlib) => match **inner {
587                 // Option<&T>/Option<Box<T>> are the ffi-safe way to specify options
588                 TypeName::Reference(..) | TypeName::Box(..) => *stdlib == StdlibOrDiplomat::Stdlib,
589                 // For other types (primitives, structs, enums) we need DiplomatOption
590                 _ => *stdlib == StdlibOrDiplomat::Diplomat,
591              }
592         }
593     }
594 
595     /// What's the FFI safe version of this type?
596     ///
597     /// This also marks DiplomatOption<&T> as FFI-unsafe: these are technically safe from an ABI standpoint
598     /// however Diplomat always expects these to be equivalent to a nullable pointer, so Option<&T> is required.
ffi_safe_version(&self) -> TypeName599     pub fn ffi_safe_version(&self) -> TypeName {
600         match self {
601             TypeName::StrReference(lt, encoding, StdlibOrDiplomat::Stdlib) => {
602                 TypeName::StrReference(lt.clone(), *encoding, StdlibOrDiplomat::Diplomat)
603             }
604             TypeName::StrSlice(encoding, StdlibOrDiplomat::Stdlib) => {
605                 TypeName::StrSlice(*encoding, StdlibOrDiplomat::Diplomat)
606             }
607             TypeName::PrimitiveSlice(ltmt, prim, StdlibOrDiplomat::Stdlib) => {
608                 TypeName::PrimitiveSlice(ltmt.clone(), *prim, StdlibOrDiplomat::Diplomat)
609             }
610             TypeName::Ordering => TypeName::Primitive(PrimitiveType::i8),
611             TypeName::Option(inner, _stdlib) => match **inner {
612                 // Option<&T>/Option<Box<T>> are the ffi-safe way to specify options
613                 TypeName::Reference(..) | TypeName::Box(..) => {
614                     TypeName::Option(inner.clone(), StdlibOrDiplomat::Stdlib)
615                 }
616                 // For other types (primitives, structs, enums) we need DiplomatOption
617                 _ => TypeName::Option(
618                     Box::new(inner.ffi_safe_version()),
619                     StdlibOrDiplomat::Diplomat,
620                 ),
621             },
622             _ => self.clone(),
623         }
624     }
625     /// Converts the [`TypeName`] back into an AST node that can be spliced into a program.
to_syn(&self) -> syn::Type626     pub fn to_syn(&self) -> syn::Type {
627         match self {
628             TypeName::Primitive(primitive) => {
629                 let primitive = primitive.to_ident();
630                 syn::parse_quote_spanned!(Span::call_site() => #primitive)
631             }
632             TypeName::Ordering => syn::parse_quote_spanned!(Span::call_site() => i8),
633             TypeName::Named(name) | TypeName::SelfType(name) => {
634                 // Self also gets expanded instead of turning into `Self` because
635                 // this code is used to generate the `extern "C"` functions, which
636                 // aren't in an impl block.
637                 let name = name.to_syn();
638                 syn::parse_quote_spanned!(Span::call_site() => #name)
639             }
640             TypeName::Reference(lifetime, mutability, underlying) => {
641                 let reference = ReferenceDisplay(lifetime, mutability);
642                 let underlying = underlying.to_syn();
643 
644                 syn::parse_quote_spanned!(Span::call_site() => #reference #underlying)
645             }
646             TypeName::Box(underlying) => {
647                 let underlying = underlying.to_syn();
648                 syn::parse_quote_spanned!(Span::call_site() => Box<#underlying>)
649             }
650             TypeName::Option(underlying, StdlibOrDiplomat::Stdlib) => {
651                 let underlying = underlying.to_syn();
652                 syn::parse_quote_spanned!(Span::call_site() => Option<#underlying>)
653             }
654             TypeName::Option(underlying, StdlibOrDiplomat::Diplomat) => {
655                 let underlying = underlying.to_syn();
656                 syn::parse_quote_spanned!(Span::call_site() => diplomat_runtime::DiplomatOption<#underlying>)
657             }
658             TypeName::Result(ok, err, StdlibOrDiplomat::Stdlib) => {
659                 let ok = ok.to_syn();
660                 let err = err.to_syn();
661                 syn::parse_quote_spanned!(Span::call_site() => Result<#ok, #err>)
662             }
663             TypeName::Result(ok, err, StdlibOrDiplomat::Diplomat) => {
664                 let ok = ok.to_syn();
665                 let err = err.to_syn();
666                 syn::parse_quote_spanned!(Span::call_site() => diplomat_runtime::DiplomatResult<#ok, #err>)
667             }
668             TypeName::Write => {
669                 syn::parse_quote_spanned!(Span::call_site() => diplomat_runtime::DiplomatWrite)
670             }
671             TypeName::StrReference(lt, encoding, is_stdlib_type) => {
672                 if *is_stdlib_type == StdlibOrDiplomat::Stdlib {
673                     encoding.get_stdlib_slice_type(lt)
674                 } else {
675                     encoding.get_diplomat_slice_type(lt)
676                 }
677             }
678             TypeName::StrSlice(encoding, is_stdlib_type) => {
679                 if *is_stdlib_type == StdlibOrDiplomat::Stdlib {
680                     let inner = encoding.get_stdlib_slice_type(&Some(Lifetime::Anonymous));
681                     syn::parse_quote_spanned!(Span::call_site() => &[#inner])
682                 } else {
683                     let inner = encoding.get_diplomat_slice_type(&Some(Lifetime::Anonymous));
684                     syn::parse_quote_spanned!(Span::call_site() => diplomat_runtime::DiplomatSlice<#inner>)
685                 }
686             }
687             TypeName::PrimitiveSlice(ltmt, primitive, is_stdlib_type) => {
688                 if *is_stdlib_type == StdlibOrDiplomat::Stdlib {
689                     primitive.get_stdlib_slice_type(ltmt)
690                 } else {
691                     primitive.get_diplomat_slice_type(ltmt)
692                 }
693             }
694 
695             TypeName::Unit => syn::parse_quote_spanned!(Span::call_site() => ()),
696             TypeName::Function(_input_types, output_type) => {
697                 let output_type = output_type.to_syn();
698                 // should be DiplomatCallback<function_output_type>
699                 syn::parse_quote_spanned!(Span::call_site() => DiplomatCallback<#output_type>)
700             }
701             TypeName::ImplTrait(trt_path) => {
702                 let trait_name =
703                     Ident::from(format!("DiplomatTraitStruct_{}", trt_path.path.elements[0]));
704                 // should be DiplomatTraitStruct_trait_name
705                 syn::parse_quote_spanned!(Span::call_site() => #trait_name)
706             }
707         }
708     }
709 
710     /// Extract a [`TypeName`] from a [`syn::Type`] AST node.
711     /// The following rules are used to infer [`TypeName`] variants:
712     /// - If the type is a path with a single element that is the name of a Rust primitive, returns a [`TypeName::Primitive`]
713     /// - If the type is a path with a single element [`Box`], returns a [`TypeName::Box`] with the type parameter recursively converted
714     /// - If the type is a path with a single element [`Option`], returns a [`TypeName::Option`] with the type parameter recursively converted
715     /// - If the type is a path with a single element `Self` and `self_path_type` is provided, returns a [`TypeName::Named`]
716     /// - If the type is a path with a single element [`Result`], returns a [`TypeName::Result`] with the type parameters recursively converted
717     /// - If the type is a path equal to [`diplomat_runtime::DiplomatResult`], returns a [`TypeName::DiplomatResult`] with the type parameters recursively converted
718     /// - If the type is a path equal to [`diplomat_runtime::DiplomatWrite`], returns a [`TypeName::Write`]
719     /// - If the type is a owned or borrowed string type, returns a [`TypeName::StrReference`]
720     /// - If the type is a owned or borrowed slice of a Rust primitive, returns a [`TypeName::PrimitiveSlice`]
721     /// - If the type is a reference (`&` or `&mut`), returns a [`TypeName::Reference`] with the referenced type recursively converted
722     /// - Otherwise, assume that the reference is to a [`CustomType`] in either the current module or another one, returns a [`TypeName::Named`]
from_syn(ty: &syn::Type, self_path_type: Option<PathType>) -> TypeName723     pub fn from_syn(ty: &syn::Type, self_path_type: Option<PathType>) -> TypeName {
724         match ty {
725             syn::Type::Reference(r) => {
726                 let lifetime = Lifetime::from(&r.lifetime);
727                 let mutability = Mutability::from_syn(&r.mutability);
728 
729                 let name = r.elem.to_token_stream().to_string();
730                 if name.starts_with("DiplomatStr") || name == "str" {
731                     if mutability.is_mutable() {
732                         panic!("mutable string references are disallowed");
733                     }
734                     if name == "DiplomatStr" {
735                         return TypeName::StrReference(
736                             Some(lifetime),
737                             StringEncoding::UnvalidatedUtf8,
738                             StdlibOrDiplomat::Stdlib,
739                         );
740                     } else if name == "DiplomatStr16" {
741                         return TypeName::StrReference(
742                             Some(lifetime),
743                             StringEncoding::UnvalidatedUtf16,
744                             StdlibOrDiplomat::Stdlib,
745                         );
746                     } else if name == "str" {
747                         return TypeName::StrReference(
748                             Some(lifetime),
749                             StringEncoding::Utf8,
750                             StdlibOrDiplomat::Stdlib,
751                         );
752                     }
753                 }
754                 if let syn::Type::Slice(slice) = &*r.elem {
755                     if let syn::Type::Path(p) = &*slice.elem {
756                         if let Some(primitive) = p
757                             .path
758                             .get_ident()
759                             .and_then(|i| PrimitiveType::from_str(i.to_string().as_str()).ok())
760                         {
761                             return TypeName::PrimitiveSlice(
762                                 Some((lifetime, mutability)),
763                                 primitive,
764                                 StdlibOrDiplomat::Stdlib,
765                             );
766                         }
767                     }
768                     if let TypeName::StrReference(
769                         Some(Lifetime::Anonymous),
770                         encoding,
771                         is_stdlib_type,
772                     ) = TypeName::from_syn(&slice.elem, self_path_type.clone())
773                     {
774                         if is_stdlib_type == StdlibOrDiplomat::Stdlib {
775                             panic!("Slice-of-slice is only supported with DiplomatRuntime slice types (DiplomatStrSlice, DiplomatStr16Slice, DiplomatUtf8StrSlice)");
776                         }
777                         return TypeName::StrSlice(encoding, StdlibOrDiplomat::Stdlib);
778                     }
779                 }
780                 TypeName::Reference(
781                     lifetime,
782                     mutability,
783                     Box::new(TypeName::from_syn(r.elem.as_ref(), self_path_type)),
784                 )
785             }
786             syn::Type::Path(p) => {
787                 let p_len = p.path.segments.len();
788                 if let Some(primitive) = p
789                     .path
790                     .get_ident()
791                     .and_then(|i| PrimitiveType::from_str(i.to_string().as_str()).ok())
792                 {
793                     TypeName::Primitive(primitive)
794                 } else if p_len >= 2
795                     && p.path.segments[p_len - 2].ident == "cmp"
796                     && p.path.segments[p_len - 1].ident == "Ordering"
797                 {
798                     TypeName::Ordering
799                 } else if p_len == 1 && p.path.segments[0].ident == "Box" {
800                     if let syn::PathArguments::AngleBracketed(type_args) =
801                         &p.path.segments[0].arguments
802                     {
803                         if let syn::GenericArgument::Type(syn::Type::Slice(slice)) =
804                             &type_args.args[0]
805                         {
806                             if let TypeName::Primitive(p) =
807                                 TypeName::from_syn(&slice.elem, self_path_type)
808                             {
809                                 TypeName::PrimitiveSlice(None, p, StdlibOrDiplomat::Stdlib)
810                             } else {
811                                 panic!("Owned slices only support primitives.")
812                             }
813                         } else if let syn::GenericArgument::Type(tpe) = &type_args.args[0] {
814                             if tpe.to_token_stream().to_string() == "DiplomatStr" {
815                                 TypeName::StrReference(
816                                     None,
817                                     StringEncoding::UnvalidatedUtf8,
818                                     StdlibOrDiplomat::Stdlib,
819                                 )
820                             } else if tpe.to_token_stream().to_string() == "DiplomatStr16" {
821                                 TypeName::StrReference(
822                                     None,
823                                     StringEncoding::UnvalidatedUtf16,
824                                     StdlibOrDiplomat::Stdlib,
825                                 )
826                             } else if tpe.to_token_stream().to_string() == "str" {
827                                 TypeName::StrReference(
828                                     None,
829                                     StringEncoding::Utf8,
830                                     StdlibOrDiplomat::Stdlib,
831                                 )
832                             } else {
833                                 TypeName::Box(Box::new(TypeName::from_syn(tpe, self_path_type)))
834                             }
835                         } else {
836                             panic!("Expected first type argument for Box to be a type")
837                         }
838                     } else {
839                         panic!("Expected angle brackets for Box type")
840                     }
841                 } else if p_len == 1 && p.path.segments[0].ident == "Option"
842                     || is_runtime_type(p, "DiplomatOption")
843                 {
844                     if let syn::PathArguments::AngleBracketed(type_args) =
845                         &p.path.segments[0].arguments
846                     {
847                         if let syn::GenericArgument::Type(tpe) = &type_args.args[0] {
848                             let stdlib = if p.path.segments[0].ident == "Option" {
849                                 StdlibOrDiplomat::Stdlib
850                             } else {
851                                 StdlibOrDiplomat::Diplomat
852                             };
853                             TypeName::Option(
854                                 Box::new(TypeName::from_syn(tpe, self_path_type)),
855                                 stdlib,
856                             )
857                         } else {
858                             panic!("Expected first type argument for Option to be a type")
859                         }
860                     } else {
861                         panic!("Expected angle brackets for Option type")
862                     }
863                 } else if p_len == 1 && p.path.segments[0].ident == "Self" {
864                     if let Some(self_path_type) = self_path_type {
865                         TypeName::SelfType(self_path_type)
866                     } else {
867                         panic!("Cannot have `Self` type outside of a method");
868                     }
869                 } else if is_runtime_type(p, "DiplomatOwnedStrSlice")
870                     || is_runtime_type(p, "DiplomatOwnedStr16Slice")
871                     || is_runtime_type(p, "DiplomatOwnedUTF8StrSlice")
872                 {
873                     let encoding = if is_runtime_type(p, "DiplomatOwnedStrSlice") {
874                         StringEncoding::UnvalidatedUtf8
875                     } else if is_runtime_type(p, "DiplomatOwnedStr16Slice") {
876                         StringEncoding::UnvalidatedUtf16
877                     } else {
878                         StringEncoding::Utf8
879                     };
880 
881                     TypeName::StrReference(None, encoding, StdlibOrDiplomat::Diplomat)
882                 } else if is_runtime_type(p, "DiplomatStrSlice")
883                     || is_runtime_type(p, "DiplomatStr16Slice")
884                     || is_runtime_type(p, "DiplomatUtf8StrSlice")
885                 {
886                     let lt = get_lifetime_from_syn_path(p);
887 
888                     let encoding = if is_runtime_type(p, "DiplomatStrSlice") {
889                         StringEncoding::UnvalidatedUtf8
890                     } else if is_runtime_type(p, "DiplomatStr16Slice") {
891                         StringEncoding::UnvalidatedUtf16
892                     } else {
893                         StringEncoding::Utf8
894                     };
895 
896                     TypeName::StrReference(Some(lt), encoding, StdlibOrDiplomat::Diplomat)
897                 } else if is_runtime_type(p, "DiplomatSlice")
898                     || is_runtime_type(p, "DiplomatSliceMut")
899                     || is_runtime_type(p, "DiplomatOwnedSlice")
900                 {
901                     let ltmut = if is_runtime_type(p, "DiplomatOwnedSlice") {
902                         let mutability = if is_runtime_type(p, "DiplomatSlice") {
903                             Mutability::Immutable
904                         } else {
905                             Mutability::Mutable
906                         };
907                         let lt = get_lifetime_from_syn_path(p);
908                         Some((lt, mutability))
909                     } else {
910                         None
911                     };
912 
913                     let ty = get_ty_from_syn_path(p).expect("Expected type argument to DiplomatSlice/DiplomatSliceMut/DiplomatOwnedSlice");
914 
915                     if let syn::Type::Path(p) = &ty {
916                         if let Some(ident) = p.path.get_ident() {
917                             let ident = ident.to_string();
918                             let i = ident.as_str();
919                             match i {
920                                 "DiplomatStrSlice" => {
921                                     return TypeName::StrSlice(
922                                         StringEncoding::UnvalidatedUtf8,
923                                         StdlibOrDiplomat::Diplomat,
924                                     )
925                                 }
926                                 "DiplomatStr16Slice" => {
927                                     return TypeName::StrSlice(
928                                         StringEncoding::UnvalidatedUtf16,
929                                         StdlibOrDiplomat::Diplomat,
930                                     )
931                                 }
932                                 "DiplomatUtf8StrSlice" => {
933                                     return TypeName::StrSlice(
934                                         StringEncoding::Utf8,
935                                         StdlibOrDiplomat::Diplomat,
936                                     )
937                                 }
938                                 _ => {
939                                     if let Ok(prim) = PrimitiveType::from_str(i) {
940                                         return TypeName::PrimitiveSlice(
941                                             ltmut,
942                                             prim,
943                                             StdlibOrDiplomat::Diplomat,
944                                         );
945                                     }
946                                 }
947                             }
948                         }
949                     }
950                     panic!("Found DiplomatSlice/DiplomatSliceMut/DiplomatOwnedSlice without primitive or DiplomatStrSlice-like generic");
951                 } else if p_len == 1 && p.path.segments[0].ident == "Result"
952                     || is_runtime_type(p, "DiplomatResult")
953                 {
954                     if let syn::PathArguments::AngleBracketed(type_args) =
955                         &p.path.segments.last().unwrap().arguments
956                     {
957                         assert!(
958                             type_args.args.len() > 1,
959                             "Not enough arguments given to Result<T,E>. Are you using a non-std Result type?"
960                         );
961 
962                         if let (syn::GenericArgument::Type(ok), syn::GenericArgument::Type(err)) =
963                             (&type_args.args[0], &type_args.args[1])
964                         {
965                             let ok = TypeName::from_syn(ok, self_path_type.clone());
966                             let err = TypeName::from_syn(err, self_path_type);
967                             TypeName::Result(
968                                 Box::new(ok),
969                                 Box::new(err),
970                                 if is_runtime_type(p, "DiplomatResult") {
971                                     StdlibOrDiplomat::Diplomat
972                                 } else {
973                                     StdlibOrDiplomat::Stdlib
974                                 },
975                             )
976                         } else {
977                             panic!("Expected both type arguments for Result to be a type")
978                         }
979                     } else {
980                         panic!("Expected angle brackets for Result type")
981                     }
982                 } else if is_runtime_type(p, "DiplomatWrite") {
983                     TypeName::Write
984                 } else {
985                     TypeName::Named(PathType::from(p))
986                 }
987             }
988             syn::Type::Tuple(tup) => {
989                 if tup.elems.is_empty() {
990                     TypeName::Unit
991                 } else {
992                     todo!("Tuples are not currently supported")
993                 }
994             }
995             syn::Type::ImplTrait(tr) => {
996                 let trait_bound = tr.bounds.first();
997                 if tr.bounds.len() > 1 {
998                     todo!("Currently don't support implementing multiple traits");
999                 }
1000                 if let Some(syn::TypeParamBound::Trait(syn::TraitBound { path: p, .. })) =
1001                     trait_bound
1002                 {
1003                     let rel_segs = &p.segments;
1004                     let path_seg = &rel_segs[0];
1005                     if path_seg.ident.eq("Fn") {
1006                         // we're in a function type
1007                         // get input and output args
1008                         if let syn::PathArguments::Parenthesized(
1009                             syn::ParenthesizedGenericArguments {
1010                                 inputs: input_types,
1011                                 output: output_type,
1012                                 ..
1013                             },
1014                         ) = &path_seg.arguments
1015                         {
1016                             let in_types = input_types
1017                                 .iter()
1018                                 .map(|in_ty| {
1019                                     Box::new(TypeName::from_syn(in_ty, self_path_type.clone()))
1020                                 })
1021                                 .collect::<Vec<Box<TypeName>>>();
1022                             let out_type = match output_type {
1023                                 syn::ReturnType::Type(_, output_type) => {
1024                                     TypeName::from_syn(output_type, self_path_type.clone())
1025                                 }
1026                                 syn::ReturnType::Default => TypeName::Unit,
1027                             };
1028                             let ret = TypeName::Function(in_types, Box::new(out_type));
1029                             return ret;
1030                         }
1031                         panic!("Unsupported function type: {:?}", &path_seg.arguments);
1032                     } else {
1033                         let ret = TypeName::ImplTrait(PathType::from(&syn::TraitBound {
1034                             paren_token: None,
1035                             modifier: syn::TraitBoundModifier::None,
1036                             lifetimes: None, // todo this is an assumption
1037                             path: p.clone(),
1038                         }));
1039                         return ret;
1040                     }
1041                 }
1042                 panic!("Unsupported trait type: {:?}", tr);
1043             }
1044             other => panic!("Unsupported type: {}", other.to_token_stream()),
1045         }
1046     }
1047 
1048     /// Returns `true` if `self` is the `TypeName::SelfType` variant, otherwise
1049     /// `false`.
is_self(&self) -> bool1050     pub fn is_self(&self) -> bool {
1051         matches!(self, TypeName::SelfType(_))
1052     }
1053 
1054     /// Recurse down the type tree, visiting all lifetimes.
1055     ///
1056     /// Using this function, you can collect all the lifetimes into a collection,
1057     /// or examine each one without having to make any additional allocations.
visit_lifetimes<'a, F, B>(&'a self, visit: &mut F) -> ControlFlow<B> where F: FnMut(&'a Lifetime, LifetimeOrigin) -> ControlFlow<B>,1058     pub fn visit_lifetimes<'a, F, B>(&'a self, visit: &mut F) -> ControlFlow<B>
1059     where
1060         F: FnMut(&'a Lifetime, LifetimeOrigin) -> ControlFlow<B>,
1061     {
1062         match self {
1063             TypeName::Named(path_type) | TypeName::SelfType(path_type) => path_type
1064                 .lifetimes
1065                 .iter()
1066                 .try_for_each(|lt| visit(lt, LifetimeOrigin::Named)),
1067             TypeName::Reference(lt, _, ty) => {
1068                 ty.visit_lifetimes(visit)?;
1069                 visit(lt, LifetimeOrigin::Reference)
1070             }
1071             TypeName::Box(ty) | TypeName::Option(ty, _) => ty.visit_lifetimes(visit),
1072             TypeName::Result(ok, err, _) => {
1073                 ok.visit_lifetimes(visit)?;
1074                 err.visit_lifetimes(visit)
1075             }
1076             TypeName::StrReference(Some(lt), ..) => visit(lt, LifetimeOrigin::StrReference),
1077             TypeName::PrimitiveSlice(Some((lt, _)), ..) => {
1078                 visit(lt, LifetimeOrigin::PrimitiveSlice)
1079             }
1080             _ => ControlFlow::Continue(()),
1081         }
1082     }
1083 
1084     /// Returns `true` if any lifetime satisfies a predicate, otherwise `false`.
1085     ///
1086     /// This method is short-circuiting, meaning that if the predicate ever succeeds,
1087     /// it will return immediately.
any_lifetime<'a, F>(&'a self, mut f: F) -> bool where F: FnMut(&'a Lifetime, LifetimeOrigin) -> bool,1088     pub fn any_lifetime<'a, F>(&'a self, mut f: F) -> bool
1089     where
1090         F: FnMut(&'a Lifetime, LifetimeOrigin) -> bool,
1091     {
1092         self.visit_lifetimes(&mut |lifetime, origin| {
1093             if f(lifetime, origin) {
1094                 ControlFlow::Break(())
1095             } else {
1096                 ControlFlow::Continue(())
1097             }
1098         })
1099         .is_break()
1100     }
1101 
1102     /// Returns `true` if all lifetimes satisfy a predicate, otherwise `false`.
1103     ///
1104     /// This method is short-circuiting, meaning that if the predicate ever fails,
1105     /// it will return immediately.
all_lifetimes<'a, F>(&'a self, mut f: F) -> bool where F: FnMut(&'a Lifetime, LifetimeOrigin) -> bool,1106     pub fn all_lifetimes<'a, F>(&'a self, mut f: F) -> bool
1107     where
1108         F: FnMut(&'a Lifetime, LifetimeOrigin) -> bool,
1109     {
1110         self.visit_lifetimes(&mut |lifetime, origin| {
1111             if f(lifetime, origin) {
1112                 ControlFlow::Continue(())
1113             } else {
1114                 ControlFlow::Break(())
1115             }
1116         })
1117         .is_continue()
1118     }
1119 
1120     /// Returns all lifetimes in a [`LifetimeEnv`] that must live at least as
1121     /// long as the type.
longer_lifetimes<'env>( &self, lifetime_env: &'env LifetimeEnv, ) -> Vec<&'env NamedLifetime>1122     pub fn longer_lifetimes<'env>(
1123         &self,
1124         lifetime_env: &'env LifetimeEnv,
1125     ) -> Vec<&'env NamedLifetime> {
1126         self.transitive_lifetime_bounds(LifetimeTransitivity::longer(lifetime_env))
1127     }
1128 
1129     /// Returns all lifetimes in a [`LifetimeEnv`] that are outlived by the type.
shorter_lifetimes<'env>( &self, lifetime_env: &'env LifetimeEnv, ) -> Vec<&'env NamedLifetime>1130     pub fn shorter_lifetimes<'env>(
1131         &self,
1132         lifetime_env: &'env LifetimeEnv,
1133     ) -> Vec<&'env NamedLifetime> {
1134         self.transitive_lifetime_bounds(LifetimeTransitivity::shorter(lifetime_env))
1135     }
1136 
1137     /// Visits the provided [`LifetimeTransitivity`] value with all `NamedLifetime`s
1138     /// in the type tree, and returns the transitively reachable lifetimes.
transitive_lifetime_bounds<'env>( &self, mut transitivity: LifetimeTransitivity<'env>, ) -> Vec<&'env NamedLifetime>1139     fn transitive_lifetime_bounds<'env>(
1140         &self,
1141         mut transitivity: LifetimeTransitivity<'env>,
1142     ) -> Vec<&'env NamedLifetime> {
1143         self.visit_lifetimes(&mut |lifetime, _| -> ControlFlow<()> {
1144             if let Lifetime::Named(named) = lifetime {
1145                 transitivity.visit(named);
1146             }
1147             ControlFlow::Continue(())
1148         });
1149         transitivity.finish()
1150     }
1151 
is_zst(&self) -> bool1152     pub fn is_zst(&self) -> bool {
1153         // check_zst() prevents non-unit types from being ZSTs
1154         matches!(*self, TypeName::Unit)
1155     }
1156 
is_pointer(&self) -> bool1157     pub fn is_pointer(&self) -> bool {
1158         matches!(*self, TypeName::Reference(..) | TypeName::Box(_))
1159     }
1160 }
1161 
/// Describes where in a [`TypeName`] tree a lifetime passed to
/// `TypeName::visit_lifetimes` was found.
#[derive(Copy, Clone, Debug, PartialEq, Eq, Hash)]
#[non_exhaustive]
pub enum LifetimeOrigin {
    /// A lifetime parameter of a named type path (`TypeName::Named` or `TypeName::SelfType`).
    Named,
    /// The lifetime of a `&` / `&mut` reference (`TypeName::Reference`).
    Reference,
    /// The lifetime of a borrowed string (`TypeName::StrReference`).
    StrReference,
    /// The lifetime of a borrowed primitive slice (`TypeName::PrimitiveSlice`).
    PrimitiveSlice,
}
1169 
is_runtime_type(p: &syn::TypePath, name: &str) -> bool1170 fn is_runtime_type(p: &syn::TypePath, name: &str) -> bool {
1171     (p.path.segments.len() == 1 && p.path.segments[0].ident == name)
1172         || (p.path.segments.len() == 2
1173             && p.path.segments[0].ident == "diplomat_runtime"
1174             && p.path.segments[1].ident == name)
1175 }
1176 
1177 impl fmt::Display for TypeName {
fmt(&self, f: &mut fmt::Formatter) -> fmt::Result1178     fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
1179         match self {
1180             TypeName::Primitive(p) => p.fmt(f),
1181             TypeName::Ordering => write!(f, "Ordering"),
1182             TypeName::Named(p) | TypeName::SelfType(p) => p.fmt(f),
1183             TypeName::Reference(lifetime, mutability, typ) => {
1184                 write!(f, "{}{typ}", ReferenceDisplay(lifetime, mutability))
1185             }
1186             TypeName::Box(typ) => write!(f, "Box<{typ}>"),
1187             TypeName::Option(typ, StdlibOrDiplomat::Stdlib) => write!(f, "Option<{typ}>"),
1188             TypeName::Option(typ, StdlibOrDiplomat::Diplomat) => write!(f, "DiplomatOption<{typ}>"),
1189             TypeName::Result(ok, err, _) => {
1190                 write!(f, "Result<{ok}, {err}>")
1191             }
1192             TypeName::Write => "DiplomatWrite".fmt(f),
1193             TypeName::StrReference(lt, encoding, is_stdlib_type) => {
1194                 if let Some(lt) = lt {
1195                     if *is_stdlib_type == StdlibOrDiplomat::Stdlib {
1196                         let lt = ReferenceDisplay(lt, &Mutability::Immutable);
1197                         let ty = encoding.get_diplomat_slice_type_str();
1198                         write!(f, "{lt}{ty}")
1199                     } else {
1200                         let ty = encoding.get_stdlib_slice_type_str();
1201                         let lt = LifetimeGenericsListDisplay(lt);
1202                         write!(f, "{ty}{lt}")
1203                     }
1204                 } else {
1205                     match (encoding, is_stdlib_type) {
1206                         (_, StdlibOrDiplomat::Stdlib) => {
1207                             write!(f, "Box<{}>", encoding.get_diplomat_slice_type_str())
1208                         }
1209                         (StringEncoding::Utf8, StdlibOrDiplomat::Diplomat) => {
1210                             "DiplomatOwnedUtf8Str".fmt(f)
1211                         }
1212                         (StringEncoding::UnvalidatedUtf8, StdlibOrDiplomat::Diplomat) => {
1213                             "DiplomatOwnedStrSlice".fmt(f)
1214                         }
1215                         (StringEncoding::UnvalidatedUtf16, StdlibOrDiplomat::Diplomat) => {
1216                             "DiplomatOwnedStr16Slice".fmt(f)
1217                         }
1218                     }
1219                 }
1220             }
1221 
1222             TypeName::StrSlice(encoding, StdlibOrDiplomat::Stdlib) => {
1223                 let inner = encoding.get_stdlib_slice_type_str();
1224 
1225                 write!(f, "&[&{inner}]")
1226             }
1227             TypeName::StrSlice(encoding, StdlibOrDiplomat::Diplomat) => {
1228                 let inner = encoding.get_diplomat_slice_type_str();
1229                 write!(f, "DiplomatSlice<{inner}>")
1230             }
1231 
1232             TypeName::PrimitiveSlice(
1233                 Some((lifetime, mutability)),
1234                 typ,
1235                 StdlibOrDiplomat::Stdlib,
1236             ) => {
1237                 write!(f, "{}[{typ}]", ReferenceDisplay(lifetime, mutability))
1238             }
1239             TypeName::PrimitiveSlice(
1240                 Some((lifetime, mutability)),
1241                 typ,
1242                 StdlibOrDiplomat::Diplomat,
1243             ) => {
1244                 let maybemut = if *mutability == Mutability::Immutable {
1245                     ""
1246                 } else {
1247                     "Mut"
1248                 };
1249                 let lt = LifetimeGenericsListPartialDisplay(lifetime);
1250                 write!(f, "DiplomatSlice{maybemut}<{lt}{typ}>")
1251             }
1252             TypeName::PrimitiveSlice(None, typ, _) => write!(f, "Box<[{typ}]>"),
1253             TypeName::Unit => "()".fmt(f),
1254             TypeName::Function(input_types, out_type) => {
1255                 write!(f, "fn (")?;
1256                 for in_typ in input_types.iter() {
1257                     write!(f, "{in_typ}")?;
1258                 }
1259                 write!(f, ")->{out_type}")
1260             }
1261             TypeName::ImplTrait(trt) => {
1262                 write!(f, "impl ")?;
1263                 trt.fmt(f)
1264             }
1265         }
1266     }
1267 }
1268 
1269 /// An [`fmt::Display`] type for formatting Rust references.
1270 ///
1271 /// # Examples
1272 ///
1273 /// ```ignore
1274 /// let lifetime = Lifetime::from(&syn::parse_str::<syn::Lifetime>("'a"));
1275 /// let mutability = Mutability::Mutable;
1276 /// // ...
1277 /// let fmt = format!("{}[u8]", ReferenceDisplay(&lifetime, &mutability));
1278 ///
1279 /// assert_eq!(fmt, "&'a mut [u8]");
1280 /// ```
1281 struct ReferenceDisplay<'a>(&'a Lifetime, &'a Mutability);
1282 
1283 impl<'a> fmt::Display for ReferenceDisplay<'a> {
fmt(&self, f: &mut fmt::Formatter) -> fmt::Result1284     fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
1285         match self.0 {
1286             Lifetime::Static => "&'static ".fmt(f)?,
1287             Lifetime::Named(lifetime) => write!(f, "&{lifetime} ")?,
1288             Lifetime::Anonymous => '&'.fmt(f)?,
1289         }
1290 
1291         if self.1.is_mutable() {
1292             "mut ".fmt(f)?;
1293         }
1294 
1295         Ok(())
1296     }
1297 }
1298 
1299 impl<'a> quote::ToTokens for ReferenceDisplay<'a> {
to_tokens(&self, tokens: &mut proc_macro2::TokenStream)1300     fn to_tokens(&self, tokens: &mut proc_macro2::TokenStream) {
1301         let lifetime = self.0.to_syn();
1302         let mutability = self.1.to_syn();
1303 
1304         tokens.append_all(quote::quote!(& #lifetime #mutability))
1305     }
1306 }
1307 
1308 /// An [`fmt::Display`] type for formatting Rust lifetimes as they show up in generics list, when
1309 /// the generics list has no other elements
1310 struct LifetimeGenericsListDisplay<'a>(&'a Lifetime);
1311 
1312 impl<'a> fmt::Display for LifetimeGenericsListDisplay<'a> {
fmt(&self, f: &mut fmt::Formatter) -> fmt::Result1313     fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
1314         match self.0 {
1315             Lifetime::Static => "<'static>".fmt(f),
1316             Lifetime::Named(lifetime) => write!(f, "<{lifetime}>"),
1317             Lifetime::Anonymous => Ok(()),
1318         }
1319     }
1320 }
1321 
1322 impl<'a> quote::ToTokens for LifetimeGenericsListDisplay<'a> {
to_tokens(&self, tokens: &mut proc_macro2::TokenStream)1323     fn to_tokens(&self, tokens: &mut proc_macro2::TokenStream) {
1324         if let Lifetime::Anonymous = self.0 {
1325         } else {
1326             let lifetime = self.0.to_syn();
1327             tokens.append_all(quote::quote!(<#lifetime>))
1328         }
1329     }
1330 }
1331 
1332 /// An [`fmt::Display`] type for formatting Rust lifetimes as they show up in generics list, when
1333 /// the generics list has another element
1334 struct LifetimeGenericsListPartialDisplay<'a>(&'a Lifetime);
1335 
1336 impl<'a> fmt::Display for LifetimeGenericsListPartialDisplay<'a> {
fmt(&self, f: &mut fmt::Formatter) -> fmt::Result1337     fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
1338         match self.0 {
1339             Lifetime::Static => "'static,".fmt(f),
1340             Lifetime::Named(lifetime) => write!(f, "{lifetime},"),
1341             Lifetime::Anonymous => Ok(()),
1342         }
1343     }
1344 }
1345 
1346 impl<'a> quote::ToTokens for LifetimeGenericsListPartialDisplay<'a> {
to_tokens(&self, tokens: &mut proc_macro2::TokenStream)1347     fn to_tokens(&self, tokens: &mut proc_macro2::TokenStream) {
1348         if let Lifetime::Anonymous = self.0 {
1349         } else {
1350             let lifetime = self.0.to_syn();
1351             tokens.append_all(quote::quote!(#lifetime,))
1352         }
1353     }
1354 }
1355 
1356 impl fmt::Display for PathType {
fmt(&self, f: &mut fmt::Formatter) -> fmt::Result1357     fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
1358         self.path.fmt(f)?;
1359 
1360         if let Some((first, rest)) = self.lifetimes.split_first() {
1361             write!(f, "<{first}")?;
1362             for lifetime in rest {
1363                 write!(f, ", {lifetime}")?;
1364             }
1365             '>'.fmt(f)?;
1366         }
1367         Ok(())
1368     }
1369 }
1370 
/// A built-in Rust primitive scalar type.
///
/// Variant names match the Rust primitive type names they stand for, which is why they are
/// lowercase. `char` and `byte` are the exceptions in code spelling; see their variant docs.
#[derive(Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize, Debug)]
// Lowercase variant names so each variant reads like the primitive type it represents.
#[allow(non_camel_case_types)]
#[allow(clippy::exhaustive_enums)] // there are only these (scalar types)
pub enum PrimitiveType {
    i8,
    u8,
    i16,
    u16,
    i32,
    u32,
    i64,
    u64,
    i128,
    u128,
    isize,
    usize,
    f32,
    f64,
    bool,
    /// A character type. Spelled `DiplomatChar` in generated/parsed code and displayed as
    /// `char`.
    char,
    /// a primitive byte that is not meant to be interpreted numerically
    /// in languages that don't have fine-grained integer types.
    /// Spelled `DiplomatByte` in generated/parsed code and displayed as `u8`.
    byte,
}
1396 
1397 impl PrimitiveType {
as_code_str(self) -> &'static str1398     fn as_code_str(self) -> &'static str {
1399         match self {
1400             PrimitiveType::i8 => "i8",
1401             PrimitiveType::u8 => "u8",
1402             PrimitiveType::i16 => "i16",
1403             PrimitiveType::u16 => "u16",
1404             PrimitiveType::i32 => "i32",
1405             PrimitiveType::u32 => "u32",
1406             PrimitiveType::i64 => "i64",
1407             PrimitiveType::u64 => "u64",
1408             PrimitiveType::i128 => "i128",
1409             PrimitiveType::u128 => "u128",
1410             PrimitiveType::isize => "isize",
1411             PrimitiveType::usize => "usize",
1412             PrimitiveType::f32 => "f32",
1413             PrimitiveType::f64 => "f64",
1414             PrimitiveType::bool => "bool",
1415             PrimitiveType::char => "DiplomatChar",
1416             PrimitiveType::byte => "DiplomatByte",
1417         }
1418     }
1419 
to_ident(self) -> proc_macro2::Ident1420     fn to_ident(self) -> proc_macro2::Ident {
1421         proc_macro2::Ident::new(self.as_code_str(), Span::call_site())
1422     }
1423 
1424     /// Get the type for a slice of this, as specified using Rust stdlib types
get_stdlib_slice_type(self, lt: &Option<(Lifetime, Mutability)>) -> syn::Type1425     pub fn get_stdlib_slice_type(self, lt: &Option<(Lifetime, Mutability)>) -> syn::Type {
1426         let primitive = self.to_ident();
1427 
1428         if let Some((ref lt, ref mtbl)) = lt {
1429             let reference = ReferenceDisplay(lt, mtbl);
1430             syn::parse_quote_spanned!(Span::call_site() => #reference [#primitive])
1431         } else {
1432             syn::parse_quote_spanned!(Span::call_site() => Box<[#primitive]>)
1433         }
1434     }
1435 
1436     /// Get the type for a slice of this, as specified using Diplomat runtime types
get_diplomat_slice_type(self, lt: &Option<(Lifetime, Mutability)>) -> syn::Type1437     pub fn get_diplomat_slice_type(self, lt: &Option<(Lifetime, Mutability)>) -> syn::Type {
1438         let primitive = self.to_ident();
1439 
1440         if let Some((lt, mtbl)) = lt {
1441             let lifetime = LifetimeGenericsListPartialDisplay(lt);
1442 
1443             if *mtbl == Mutability::Immutable {
1444                 syn::parse_quote_spanned!(Span::call_site() => diplomat_runtime::DiplomatSlice<#lifetime #primitive>)
1445             } else {
1446                 syn::parse_quote_spanned!(Span::call_site() => diplomat_runtime::DiplomatSliceMut<#lifetime #primitive>)
1447             }
1448         } else {
1449             syn::parse_quote_spanned!(Span::call_site() => diplomat_runtime::DiplomatOwnedSlice<#primitive>)
1450         }
1451     }
1452 }
1453 
1454 impl fmt::Display for PrimitiveType {
fmt(&self, f: &mut fmt::Formatter) -> fmt::Result1455     fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
1456         match self {
1457             PrimitiveType::byte => "u8",
1458             PrimitiveType::char => "char",
1459             _ => self.as_code_str(),
1460         }
1461         .fmt(f)
1462     }
1463 }
1464 
1465 impl FromStr for PrimitiveType {
1466     type Err = ();
from_str(s: &str) -> Result<Self, ()>1467     fn from_str(s: &str) -> Result<Self, ()> {
1468         Ok(match s {
1469             "i8" => PrimitiveType::i8,
1470             "u8" => PrimitiveType::u8,
1471             "i16" => PrimitiveType::i16,
1472             "u16" => PrimitiveType::u16,
1473             "i32" => PrimitiveType::i32,
1474             "u32" => PrimitiveType::u32,
1475             "i64" => PrimitiveType::i64,
1476             "u64" => PrimitiveType::u64,
1477             "i128" => PrimitiveType::i128,
1478             "u128" => PrimitiveType::u128,
1479             "isize" => PrimitiveType::isize,
1480             "usize" => PrimitiveType::usize,
1481             "f32" => PrimitiveType::f32,
1482             "f64" => PrimitiveType::f64,
1483             "bool" => PrimitiveType::bool,
1484             "DiplomatChar" => PrimitiveType::char,
1485             "DiplomatByte" => PrimitiveType::byte,
1486             _ => return Err(()),
1487         })
1488     }
1489 }
1490 
#[cfg(test)]
mod tests {
    // Snapshot tests for `TypeName::from_syn`: each case builds a Rust type with
    // `syn::parse_quote!`, converts it with `from_syn` (always passing `None` as the second
    // argument), and compares the result against a stored YAML snapshot via `insta`.
    //
    // NOTE(review): insta names unnamed snapshots after the test function plus the call's
    // position within it — verify against the stored `.snap` files before reordering or
    // inserting cases.
    use insta;

    use syn;

    use super::TypeName;

    // Primitive scalar types.
    #[test]
    fn typename_primitives() {
        insta::assert_yaml_snapshot!(TypeName::from_syn(
            &syn::parse_quote! {
                i32
            },
            None
        ));

        insta::assert_yaml_snapshot!(TypeName::from_syn(
            &syn::parse_quote! {
                usize
            },
            None
        ));

        insta::assert_yaml_snapshot!(TypeName::from_syn(
            &syn::parse_quote! {
                bool
            },
            None
        ));
    }

    // A bare, unqualified named type.
    #[test]
    fn typename_named() {
        insta::assert_yaml_snapshot!(TypeName::from_syn(
            &syn::parse_quote! {
                MyLocalStruct
            },
            None
        ));
    }

    // Shared and mutable references, to a primitive and to a named type.
    #[test]
    fn typename_references() {
        insta::assert_yaml_snapshot!(TypeName::from_syn(
            &syn::parse_quote! {
                &i32
            },
            None
        ));

        insta::assert_yaml_snapshot!(TypeName::from_syn(
            &syn::parse_quote! {
                &mut MyLocalStruct
            },
            None
        ));
    }

    // `Box<T>` around a primitive and around a named type.
    #[test]
    fn typename_boxes() {
        insta::assert_yaml_snapshot!(TypeName::from_syn(
            &syn::parse_quote! {
                Box<i32>
            },
            None
        ));

        insta::assert_yaml_snapshot!(TypeName::from_syn(
            &syn::parse_quote! {
                Box<MyLocalStruct>
            },
            None
        ));
    }

    // `Option<T>` around a primitive and around a named type.
    #[test]
    fn typename_option() {
        insta::assert_yaml_snapshot!(TypeName::from_syn(
            &syn::parse_quote! {
                Option<i32>
            },
            None
        ));

        insta::assert_yaml_snapshot!(TypeName::from_syn(
            &syn::parse_quote! {
                Option<MyLocalStruct>
            },
            None
        ));
    }

    // Both the `DiplomatResult` spelling and std `Result`, including a unit `Ok` type.
    #[test]
    fn typename_result() {
        insta::assert_yaml_snapshot!(TypeName::from_syn(
            &syn::parse_quote! {
                DiplomatResult<MyLocalStruct, i32>
            },
            None
        ));

        insta::assert_yaml_snapshot!(TypeName::from_syn(
            &syn::parse_quote! {
                DiplomatResult<(), MyLocalStruct>
            },
            None
        ));

        insta::assert_yaml_snapshot!(TypeName::from_syn(
            &syn::parse_quote! {
                Result<MyLocalStruct, i32>
            },
            None
        ));

        insta::assert_yaml_snapshot!(TypeName::from_syn(
            &syn::parse_quote! {
                Result<(), MyLocalStruct>
            },
            None
        ));
    }

    // Named types with lifetime arguments: relative and absolute (`::`-prefixed) paths with and
    // without lifetimes, long paths, and lifetimes nested inside `Option` / `Result`.
    #[test]
    fn lifetimes() {
        insta::assert_yaml_snapshot!(TypeName::from_syn(
            &syn::parse_quote! {
                Foo<'a, 'b>
            },
            None
        ));

        insta::assert_yaml_snapshot!(TypeName::from_syn(
            &syn::parse_quote! {
                ::core::my_type::Foo
            },
            None
        ));

        insta::assert_yaml_snapshot!(TypeName::from_syn(
            &syn::parse_quote! {
                ::core::my_type::Foo<'test>
            },
            None
        ));

        insta::assert_yaml_snapshot!(TypeName::from_syn(
            &syn::parse_quote! {
                Option<Ref<'object>>
            },
            None
        ));

        insta::assert_yaml_snapshot!(TypeName::from_syn(
            &syn::parse_quote! {
                Foo<'a, 'b, 'c, 'd>
            },
            None
        ));

        insta::assert_yaml_snapshot!(TypeName::from_syn(
            &syn::parse_quote! {
                very::long::path::to::my::Type<'x, 'y, 'z>
            },
            None
        ));

        insta::assert_yaml_snapshot!(TypeName::from_syn(
            &syn::parse_quote! {
                Result<OkRef<'a, 'b>, ErrRef<'c>>
            },
            None
        ));
    }
}
1667