• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 use proc_macro2::{Span, TokenStream};
2 use quote::{quote, ToTokens};
3 use syn::*;
4 
5 use diplomat_core::ast::{self, StdlibOrDiplomat};
6 
7 mod enum_convert;
8 mod transparent_convert;
9 
cfgs_to_stream(attrs: &[Attribute]) -> proc_macro2::TokenStream10 fn cfgs_to_stream(attrs: &[Attribute]) -> proc_macro2::TokenStream {
11     attrs
12         .iter()
13         .fold(quote!(), |prev, attr| quote!(#prev #attr))
14 }
15 
param_ty(param_ty: &ast::TypeName) -> syn::Type16 fn param_ty(param_ty: &ast::TypeName) -> syn::Type {
17     match &param_ty {
18         ast::TypeName::StrReference(lt @ Some(_lt), encoding, _) => {
19             // At the param boundary we MUST use FFI-safe diplomat slice types,
20             // not Rust stdlib types (which are not FFI-safe and must be converted)
21             encoding.get_diplomat_slice_type(lt)
22         }
23         ast::TypeName::StrReference(None, encoding, _) => encoding.get_diplomat_slice_type(&None),
24         ast::TypeName::StrSlice(encoding, _) => {
25             // At the param boundary we MUST use FFI-safe diplomat slice types,
26             // not Rust stdlib types (which are not FFI-safe and must be converted)
27             let inner = encoding.get_diplomat_slice_type(&Some(ast::Lifetime::Anonymous));
28             syn::parse_quote_spanned!(Span::call_site() => diplomat_runtime::DiplomatSlice<#inner>)
29         }
30         ast::TypeName::PrimitiveSlice(ltmt, prim, _) => {
31             // At the param boundary we MUST use FFI-safe diplomat slice types,
32             // not Rust stdlib types (which are not FFI-safe and must be converted)
33             prim.get_diplomat_slice_type(ltmt)
34         }
35         ast::TypeName::Option(..) if !param_ty.is_ffi_safe() => {
36             param_ty.ffi_safe_version().to_syn()
37         }
38         _ => param_ty.to_syn(),
39     }
40 }
41 
param_conversion( name: &ast::Ident, param_type: &ast::TypeName, cast_to: Option<&syn::Type>, ) -> Option<proc_macro2::TokenStream>42 fn param_conversion(
43     name: &ast::Ident,
44     param_type: &ast::TypeName,
45     cast_to: Option<&syn::Type>,
46 ) -> Option<proc_macro2::TokenStream> {
47     match &param_type {
48         // conversion only needed for slices that are specified as Rust types rather than diplomat_runtime types
49         ast::TypeName::StrReference(.., StdlibOrDiplomat::Stdlib)
50         | ast::TypeName::StrSlice(.., StdlibOrDiplomat::Stdlib)
51         | ast::TypeName::PrimitiveSlice(.., StdlibOrDiplomat::Stdlib)
52         | ast::TypeName::Result(..) => Some(if let Some(cast_to) = cast_to {
53             quote!(let #name: #cast_to = #name.into();)
54         } else {
55             quote!(let #name = #name.into();)
56         }),
57         // Convert Option<struct/enum/primitive> and DiplomatOption<opaque>
58         // simplify the check by just checking is_ffi_safe()
59         ast::TypeName::Option(inner, _stdlib) => {
60             let mut tokens = TokenStream::new();
61 
62             if !param_type.is_ffi_safe() {
63                 let inner_ty = inner.ffi_safe_version().to_syn();
64                 tokens.extend(quote!(let #name : Option<#inner_ty> = #name.into();));
65             }
66             if !inner.is_ffi_safe() {
67                 tokens.extend(quote!(let #name = #name.map(|v| v.into());));
68             }
69 
70             if !tokens.is_empty() {
71                 Some(tokens)
72             } else {
73                 None
74             }
75         }
76         ast::TypeName::Function(in_types, out_type) => {
77             let cb_wrap_ident = &name;
78             let mut cb_param_list = vec![];
79             let mut cb_params_and_types_list = vec![];
80             let mut cb_arg_type_list = vec![];
81             let mut all_params_conversion = vec![];
82             for (index, in_ty) in in_types.iter().enumerate() {
83                 let param_ident_str = format!("arg{}", index);
84                 let orig_type = in_ty.to_syn();
85                 let param_converted_type = param_ty(in_ty);
86                 if let Some(conversion) = param_conversion(
87                     &ast::Ident::from(param_ident_str.clone()),
88                     in_ty,
89                     Some(&param_converted_type),
90                 ) {
91                     all_params_conversion.push(conversion);
92                 }
93                 let param_ident = Ident::new(&param_ident_str, Span::call_site());
94                 cb_arg_type_list.push(param_converted_type);
95                 cb_params_and_types_list.push(quote!(#param_ident: #orig_type));
96                 cb_param_list.push(param_ident);
97             }
98             let cb_ret_type = out_type.to_syn();
99 
100             let tokens = quote! {
101                 let #cb_wrap_ident = move | #(#cb_params_and_types_list,)* | unsafe {
102                     #(#all_params_conversion)*
103                     std::mem::transmute::<unsafe extern "C" fn (*const c_void, ...) -> #cb_ret_type, unsafe extern "C" fn (*const c_void, #(#cb_arg_type_list,)*) -> #cb_ret_type>
104                         (#cb_wrap_ident.run_callback)(#cb_wrap_ident.data, #(#cb_param_list,)*)
105                 };
106             };
107             Some(parse2(tokens).unwrap())
108         }
109         _ => None,
110     }
111 }
112 
gen_custom_vtable(custom_trait: &ast::Trait, custom_trait_vtable_type: &Ident) -> Item113 fn gen_custom_vtable(custom_trait: &ast::Trait, custom_trait_vtable_type: &Ident) -> Item {
114     let mut method_sigs: Vec<proc_macro2::TokenStream> = vec![];
115     method_sigs.push(quote!(
116         pub destructor: Option<unsafe extern "C" fn(*const c_void)>,
117         pub size: usize,
118         pub alignment: usize,
119     ));
120     for m in &custom_trait.methods {
121         // TODO check that this is the right conversion, it might be the wrong direction
122         let mut param_types: Vec<syn::Type> = m.params.iter().map(|p| param_ty(&p.ty)).collect();
123         let method_name = Ident::new(&format!("run_{}_callback", m.name), Span::call_site());
124         let return_tokens = match &m.output_type {
125             Some(ret_ty) => {
126                 let conv_ret_ty = ret_ty.to_syn();
127                 quote!( -> #conv_ret_ty)
128             }
129             None => {
130                 quote! {}
131             }
132         };
133         param_types.insert(0, syn::parse_quote!(*const c_void));
134         method_sigs.push(quote!(
135             pub #method_name: unsafe extern "C" fn (#(#param_types),*) #return_tokens,
136 
137         ));
138     }
139     syn::parse_quote!(
140         #[repr(C)]
141         pub struct #custom_trait_vtable_type {
142             #(#method_sigs)*
143         }
144     )
145 }
146 
gen_custom_trait_impl(custom_trait: &ast::Trait, custom_trait_struct_name: &Ident) -> Item147 fn gen_custom_trait_impl(custom_trait: &ast::Trait, custom_trait_struct_name: &Ident) -> Item {
148     let mut methods: Vec<Item> = vec![];
149     for m in &custom_trait.methods {
150         let param_names: Vec<proc_macro2::TokenStream> = m
151             .params
152             .iter()
153             .map(|p| {
154                 let p_name = &p.name;
155                 quote! {, #p_name}
156             })
157             .collect();
158         let mut all_params_conversion = vec![];
159         let mut param_names_and_types: Vec<proc_macro2::TokenStream> = m
160             .params
161             .iter()
162             .map(|p| {
163                 let orig_type = p.ty.to_syn();
164                 let p_ty = param_ty(&p.ty);
165                 if let Some(conversion) = param_conversion(&p.name.clone(), &p.ty, Some(&p_ty)) {
166                     all_params_conversion.push(conversion);
167                 }
168                 let p_name = &p.name;
169                 quote!(#p_name : #orig_type)
170             })
171             .collect();
172         let method_name = &m.name;
173         let (return_tokens, end_token) = match &m.output_type {
174             Some(ret_ty) => {
175                 let conv_ret_ty = ret_ty.to_syn();
176                 (quote!( -> #conv_ret_ty), quote! {})
177             }
178             None => (quote! {}, quote! {;}),
179         };
180         if let Some(self_param) = &m.self_param {
181             let mut self_modifier = quote! {};
182             if let Some((lifetime, mutability)) = &self_param.reference {
183                 let lifetime_mod = if *lifetime == ast::Lifetime::Anonymous {
184                     quote! { & }
185                 } else {
186                     let prime = "'".to_string();
187                     let lifetime = lifetime.to_syn();
188                     quote! { & #prime #lifetime }
189                 };
190                 let mutability_mod = if *mutability == ast::Mutability::Mutable {
191                     quote! {mut}
192                 } else {
193                     quote! {}
194                 };
195                 self_modifier = quote! { #lifetime_mod #mutability_mod }
196             }
197             param_names_and_types.insert(0, quote!(#self_modifier self));
198         }
199 
200         let lifetimes = {
201             let lifetime_env = &m.lifetimes;
202             if lifetime_env.is_empty() {
203                 quote! {}
204             } else {
205                 quote! { <#lifetime_env> }
206             }
207         };
208         let runner_method_name =
209             Ident::new(&format!("run_{}_callback", method_name), Span::call_site());
210         methods.push(syn::Item::Fn(syn::parse_quote!(
211             fn #method_name #lifetimes (#(#param_names_and_types),*) #return_tokens {
212                 unsafe {
213                     #(#all_params_conversion)*
214                     ((self.vtable).#runner_method_name)(self.data #(#param_names)*)#end_token
215                 }
216             }
217 
218         )));
219     }
220     let trait_name = &custom_trait.name;
221     syn::parse_quote!(
222         impl #trait_name for #custom_trait_struct_name {
223             #(#methods)*
224         }
225     )
226 }
227 
gen_custom_type_method(strct: &ast::CustomType, m: &ast::Method) -> Item228 fn gen_custom_type_method(strct: &ast::CustomType, m: &ast::Method) -> Item {
229     let self_ident = Ident::new(strct.name().as_str(), Span::call_site());
230     let method_ident = Ident::new(m.name.as_str(), Span::call_site());
231     let extern_ident = Ident::new(m.abi_name.as_str(), Span::call_site());
232 
233     let mut all_params = vec![];
234 
235     let mut all_params_conversion = vec![];
236     let mut all_params_names = vec![];
237     m.params.iter().for_each(|p| {
238         let ty = param_ty(&p.ty);
239         let name = &p.name;
240         all_params_names.push(name);
241         all_params.push(syn::parse_quote!(#name: #ty));
242         if let Some(conversion) = param_conversion(&p.name, &p.ty, None) {
243             all_params_conversion.push(conversion);
244         }
245     });
246 
247     let this_ident = Pat::Ident(PatIdent {
248         attrs: vec![],
249         by_ref: None,
250         mutability: None,
251         ident: Ident::new("this", Span::call_site()),
252         subpat: None,
253     });
254 
255     if let Some(self_param) = &m.self_param {
256         all_params.insert(
257             0,
258             FnArg::Typed(PatType {
259                 attrs: vec![],
260                 pat: Box::new(this_ident.clone()),
261                 colon_token: syn::token::Colon(Span::call_site()),
262                 ty: Box::new(self_param.to_typename().to_syn()),
263             }),
264         );
265     }
266 
267     let lifetimes = {
268         let lifetime_env = &m.lifetime_env;
269         if lifetime_env.is_empty() {
270             quote! {}
271         } else {
272             quote! { <#lifetime_env> }
273         }
274     };
275 
276     let method_invocation = if m.self_param.is_some() {
277         quote! { #this_ident.#method_ident }
278     } else {
279         quote! { #self_ident::#method_ident }
280     };
281 
282     let (return_tokens, maybe_into) = if let Some(return_type) = &m.return_type {
283         if let ast::TypeName::Result(ok, err, StdlibOrDiplomat::Stdlib) = return_type {
284             let ok = ok.to_syn();
285             let err = err.to_syn();
286             (
287                 quote! { -> diplomat_runtime::DiplomatResult<#ok, #err> },
288                 quote! { .into() },
289             )
290         } else if let ast::TypeName::StrReference(_, _, StdlibOrDiplomat::Stdlib)
291         | ast::TypeName::StrSlice(.., StdlibOrDiplomat::Stdlib)
292         | ast::TypeName::PrimitiveSlice(_, _, StdlibOrDiplomat::Stdlib) = return_type
293         {
294             let return_type_syn = return_type.ffi_safe_version().to_syn();
295             (quote! { -> #return_type_syn }, quote! { .into() })
296         } else if let ast::TypeName::Ordering = return_type {
297             let return_type_syn = return_type.to_syn();
298             (quote! { -> #return_type_syn }, quote! { as i8 })
299         } else if let ast::TypeName::Option(ty, is_std_option) = return_type {
300             match ty.as_ref() {
301                 // pass by reference, Option becomes null
302                 ast::TypeName::Box(..) | ast::TypeName::Reference(..) => {
303                     let return_type_syn = return_type.to_syn();
304                     let conversion = if *is_std_option == StdlibOrDiplomat::Stdlib {
305                         quote! {}
306                     } else {
307                         quote! {.into()}
308                     };
309                     (quote! { -> #return_type_syn }, conversion)
310                 }
311                 // anything else goes through DiplomatResult
312                 _ => {
313                     let ty = ty.to_syn();
314                     let conversion = if *is_std_option == StdlibOrDiplomat::Stdlib {
315                         quote! { .ok_or(()).into() }
316                     } else {
317                         quote! {}
318                     };
319                     (
320                         quote! { -> diplomat_runtime::DiplomatResult<#ty, ()> },
321                         conversion,
322                     )
323                 }
324             }
325         } else {
326             let return_type_syn = return_type.to_syn();
327             (quote! { -> #return_type_syn }, quote! {})
328         }
329     } else {
330         (quote! {}, quote! {})
331     };
332 
333     let write_flushes = m
334         .params
335         .iter()
336         .filter(|p| p.is_write())
337         .map(|p| {
338             let p = &p.name;
339             quote! { #p.flush(); }
340         })
341         .collect::<Vec<_>>();
342 
343     let cfg = cfgs_to_stream(&m.attrs.cfg);
344     if write_flushes.is_empty() {
345         Item::Fn(syn::parse_quote! {
346             #[no_mangle]
347             #cfg
348             extern "C" fn #extern_ident #lifetimes(#(#all_params),*) #return_tokens {
349                 #(#all_params_conversion)*
350                 #method_invocation(#(#all_params_names),*) #maybe_into
351             }
352         })
353     } else {
354         Item::Fn(syn::parse_quote! {
355             #[no_mangle]
356             #cfg
357             extern "C" fn #extern_ident #lifetimes(#(#all_params),*) #return_tokens {
358                 #(#all_params_conversion)*
359                 let ret = #method_invocation(#(#all_params_names),*);
360                 #(#write_flushes)*
361                 ret #maybe_into
362             }
363         })
364     }
365 }
366 
367 struct AttributeInfo {
368     repr: bool,
369     opaque: bool,
370     #[allow(unused)]
371     is_out: bool,
372 }
373 
374 impl AttributeInfo {
extract(attrs: &mut Vec<Attribute>) -> Self375     fn extract(attrs: &mut Vec<Attribute>) -> Self {
376         let mut repr = false;
377         let mut opaque = false;
378         let mut is_out = false;
379         attrs.retain(|attr| {
380             let ident = &attr.path().segments.iter().next().unwrap().ident;
381             if ident == "repr" {
382                 repr = true;
383                 // don't actually extract repr attrs, just detect them
384                 return true;
385             } else if ident == "diplomat" {
386                 if attr.path().segments.len() == 2 {
387                     let seg = &attr.path().segments.iter().nth(1).unwrap().ident;
388                     if seg == "opaque" {
389                         opaque = true;
390                         return false;
391                     } else if seg == "out" {
392                         is_out = true;
393                         return false;
394                     } else if seg == "rust_link"
395                         || seg == "out"
396                         || seg == "attr"
397                         || seg == "abi_rename"
398                         || seg == "demo"
399                     {
400                         // diplomat-tool reads these, not diplomat::bridge.
401                         // throw them away so rustc doesn't complain about unknown attributes
402                         return false;
403                     } else if seg == "enum_convert" || seg == "transparent_convert" {
404                         // diplomat::bridge doesn't read this, but it's handled separately
405                         // as an attribute
406                         return true;
407                     } else {
408                         panic!("Only #[diplomat::opaque] and #[diplomat::rust_link] are supported: {:?}", seg)
409                     }
410                 } else {
411                     panic!("#[diplomat::foo] attrs have a single-segment path name")
412                 }
413             }
414             true
415         });
416 
417         Self {
418             repr,
419             opaque,
420             is_out,
421         }
422     }
423 }
424 
gen_bridge(mut input: ItemMod) -> ItemMod425 fn gen_bridge(mut input: ItemMod) -> ItemMod {
426     let module = ast::Module::from_syn(&input, true);
427     // Clean out any diplomat attributes so Rust doesn't get mad
428     let _attrs = AttributeInfo::extract(&mut input.attrs);
429     let (brace, mut new_contents) = input.content.unwrap();
430 
431     new_contents.push(parse2(quote! { use diplomat_runtime::*; }).unwrap());
432     new_contents.push(parse2(quote! { use core::ffi::c_void; }).unwrap());
433 
434     new_contents.iter_mut().for_each(|c| match c {
435         Item::Struct(s) => {
436             let info = AttributeInfo::extract(&mut s.attrs);
437 
438             if !info.opaque {
439                 // This is validated by HIR, but it's also nice to validate it in the macro so that there
440                 // are early error messages
441                 for field in s.fields.iter_mut() {
442                     let _attrs = AttributeInfo::extract(&mut field.attrs);
443                     let ty = ast::TypeName::from_syn(&field.ty, None);
444                     if !ty.is_ffi_safe() {
445                         let ffisafe = ty.ffi_safe_version();
446                         panic!(
447                             "Found non-FFI safe type inside struct: {}, try {}",
448                             ty, ffisafe
449                         );
450                     }
451                 }
452             }
453 
454             // Normal opaque types don't need repr(transparent) because the inner type is
455             // never referenced. #[diplomat::transparent_convert] handles adding repr(transparent)
456             // on its own
457             if !info.opaque {
458                 let repr = if !info.repr {
459                     quote!(#[repr(C)])
460                 } else {
461                     quote!()
462                 };
463 
464                 *s = syn::parse_quote! {
465                     #repr
466                     #s
467                 }
468             }
469         }
470 
471         Item::Enum(e) => {
472             let info = AttributeInfo::extract(&mut e.attrs);
473 
474             for v in &mut e.variants {
475                 let info = AttributeInfo::extract(&mut v.attrs);
476                 if info.opaque {
477                     panic!("#[diplomat::opaque] not allowed on enum variants");
478                 }
479             }
480 
481             // Normal opaque types don't need repr(transparent) because the inner type is
482             // never referenced.
483             if !info.opaque {
484                 *e = syn::parse_quote! {
485                     #[repr(C)]
486                     #[derive(Clone, Copy)]
487                     #e
488                 };
489             }
490         }
491 
492         Item::Impl(i) => {
493             for item in &mut i.items {
494                 if let syn::ImplItem::Fn(ref mut m) = *item {
495                     let info = AttributeInfo::extract(&mut m.attrs);
496                     if info.opaque {
497                         panic!("#[diplomat::opaque] not allowed on methods")
498                     }
499                     for i in m.sig.inputs.iter_mut() {
500                         let _attrs = match i {
501                             syn::FnArg::Receiver(s) => AttributeInfo::extract(&mut s.attrs),
502                             syn::FnArg::Typed(t) => AttributeInfo::extract(&mut t.attrs),
503                         };
504                     }
505                 }
506             }
507         }
508         _ => (),
509     });
510 
511     for custom_type in module.declared_types.values() {
512         custom_type.methods().iter().for_each(|m| {
513             let gen_m = gen_custom_type_method(custom_type, m);
514             new_contents.push(gen_m);
515         });
516 
517         if let ast::CustomType::Opaque(opaque) = custom_type {
518             let destroy_ident = Ident::new(opaque.dtor_abi_name.as_str(), Span::call_site());
519 
520             let type_ident = custom_type.name().to_syn();
521 
522             let (lifetime_defs, lifetimes) = if let Some(lifetime_env) = custom_type.lifetimes() {
523                 (
524                     quote! { <#lifetime_env> },
525                     lifetime_env.lifetimes_to_tokens(),
526                 )
527             } else {
528                 (quote! {}, quote! {})
529             };
530 
531             let cfg = cfgs_to_stream(&custom_type.attrs().cfg);
532 
533             // for now, body is empty since all we need to do is drop the box
534             // TODO(#13): change to take a `*mut` and handle DST boxes appropriately
535             new_contents.push(Item::Fn(syn::parse_quote! {
536                 #[no_mangle]
537                 #cfg
538                 extern "C" fn #destroy_ident #lifetime_defs(this: Box<#type_ident #lifetimes>) {}
539             }));
540         }
541     }
542 
543     for custom_trait in module.declared_traits.values() {
544         let custom_trait_name = Ident::new(
545             &format!("DiplomatTraitStruct_{}", custom_trait.name),
546             Span::call_site(),
547         );
548         let custom_trait_vtable_type =
549             Ident::new(&format!("{}_VTable", custom_trait.name), Span::call_site());
550 
551         // vtable
552         new_contents.push(gen_custom_vtable(custom_trait, &custom_trait_vtable_type));
553 
554         // trait struct
555         new_contents.push(syn::parse_quote! {
556             #[repr(C)]
557             pub struct #custom_trait_name {
558                 data: *const c_void,
559                 pub vtable: #custom_trait_vtable_type,
560             }
561         });
562         if custom_trait.is_send {
563             new_contents.push(syn::parse_quote! {
564                 unsafe impl std::marker::Send for #custom_trait_name {}
565             });
566         }
567         if custom_trait.is_sync {
568             new_contents.push(syn::parse_quote! {
569                 unsafe impl std::marker::Sync for #custom_trait_name {}
570             });
571         }
572 
573         // trait struct wrapper for all methods
574         new_contents.push(gen_custom_trait_impl(custom_trait, &custom_trait_name));
575 
576         // destructor
577         new_contents.push(syn::parse_quote! {
578             impl Drop for #custom_trait_name {
579                 fn drop(&mut self) {
580                     if let Some(destructor) = self.vtable.destructor {
581                         unsafe {
582                             (destructor)(self.data);
583                         }
584                     }
585                 }
586             }
587         })
588     }
589 
590     ItemMod {
591         attrs: input.attrs,
592         vis: input.vis,
593         mod_token: input.mod_token,
594         ident: input.ident,
595         content: Some((brace, new_contents)),
596         semi: input.semi,
597         unsafety: None,
598     }
599 }
600 
601 /// Mark a module to be exposed through Diplomat-generated FFI.
602 #[proc_macro_attribute]
bridge( _attr: proc_macro::TokenStream, input: proc_macro::TokenStream, ) -> proc_macro::TokenStream603 pub fn bridge(
604     _attr: proc_macro::TokenStream,
605     input: proc_macro::TokenStream,
606 ) -> proc_macro::TokenStream {
607     let expanded = gen_bridge(parse_macro_input!(input));
608     proc_macro::TokenStream::from(expanded.to_token_stream())
609 }
610 
611 /// Generate From and Into implementations for a Diplomat enum
612 ///
613 /// This is invoked as `#[diplomat::enum_convert(OtherEnumName)]`
614 /// on a Diplomat enum. It will assume the other enum has exactly the same variants
615 /// and generate From and Into implementations using those. In case that enum is `#[non_exhaustive]`,
616 /// you may use `#[diplomat::enum_convert(OtherEnumName, needs_wildcard)]` to generate a panicky wildcard
617 /// branch. It is up to the library author to ensure the enums are kept in sync. You may use the `#[non_exhaustive_omitted_patterns]`
618 /// lint to enforce this.
619 #[proc_macro_attribute]
enum_convert( attr: proc_macro::TokenStream, input: proc_macro::TokenStream, ) -> proc_macro::TokenStream620 pub fn enum_convert(
621     attr: proc_macro::TokenStream,
622     input: proc_macro::TokenStream,
623 ) -> proc_macro::TokenStream {
624     // proc macros handle compile errors by using special error tokens.
625     // In case of an error, we don't want the original code to go away too
626     // (otherwise that will cause more errors) so we hold on to it and we tack it in
627     // with no modifications below
628     let input_cached: proc_macro2::TokenStream = input.clone().into();
629     let expanded =
630         enum_convert::gen_enum_convert(parse_macro_input!(attr), parse_macro_input!(input));
631 
632     let full = quote! {
633         #expanded
634         #input_cached
635     };
636     proc_macro::TokenStream::from(full.to_token_stream())
637 }
638 
639 /// Generate conversions from inner types for opaque Diplomat types with a single field
640 ///
641 /// This is invoked as `#[diplomat::transparent_convert]`
642 /// on an opaque Diplomat type. It will add `#[repr(transparent)]` and implement `pub(crate) fn transparent_convert()`
643 /// which allows constructing an `&Self` from a reference to the inner field.
644 #[proc_macro_attribute]
transparent_convert( _attr: proc_macro::TokenStream, input: proc_macro::TokenStream, ) -> proc_macro::TokenStream645 pub fn transparent_convert(
646     _attr: proc_macro::TokenStream,
647     input: proc_macro::TokenStream,
648 ) -> proc_macro::TokenStream {
649     // proc macros handle compile errors by using special error tokens.
650     // In case of an error, we don't want the original code to go away too
651     // (otherwise that will cause more errors) so we hold on to it and we tack it in
652     // with no modifications below
653     let input_cached: proc_macro2::TokenStream = input.clone().into();
654     let expanded = transparent_convert::gen_transparent_convert(parse_macro_input!(input));
655 
656     let full = quote! {
657         #expanded
658         #[repr(transparent)]
659         #input_cached
660     };
661     proc_macro::TokenStream::from(full.to_token_stream())
662 }
663 
664 #[cfg(test)]
665 mod tests {
666     use std::fs::File;
667     use std::io::{Read, Write};
668     use std::process::Command;
669 
670     use quote::ToTokens;
671     use syn::parse_quote;
672     use tempfile::tempdir;
673 
674     use super::gen_bridge;
675 
rustfmt_code(code: &str) -> String676     fn rustfmt_code(code: &str) -> String {
677         let dir = tempdir().unwrap();
678         let file_path = dir.path().join("temp.rs");
679         let mut file = File::create(file_path.clone()).unwrap();
680 
681         writeln!(file, "{code}").unwrap();
682         drop(file);
683 
684         Command::new("rustfmt")
685             .arg(file_path.to_str().unwrap())
686             .spawn()
687             .unwrap()
688             .wait()
689             .unwrap();
690 
691         let mut file = File::open(file_path).unwrap();
692         let mut data = String::new();
693         file.read_to_string(&mut data).unwrap();
694         drop(file);
695         dir.close().unwrap();
696         data
697     }
698 
/// Snapshot of the bridge glue for a static method taking `&DiplomatStr`.
#[test]
fn method_taking_str() {
    let bridged = gen_bridge(parse_quote! {
        mod ffi {
            struct Foo {}

            impl Foo {
                pub fn from_str(s: &DiplomatStr) {
                    unimplemented!()
                }
            }
        }
    });
    let source = bridged.to_token_stream().to_string();
    insta::assert_snapshot!(rustfmt_code(&source));
}
717 
/// Snapshot of the bridge glue for a struct of slice fields. It covers methods taking and
/// returning borrowed stdlib slices/strings, diplomat_runtime slice types, and boxed/owned
/// slices.
#[test]
fn slices() {
    let bridged = gen_bridge(parse_quote! {
        mod ffi {
            use diplomat_runtime::{DiplomatStr, DiplomatStr16, DiplomatByte, DiplomatOwnedSlice,
                                   DiplomatOwnedStr16Slice, DiplomatOwnedStrSlice, DiplomatOwnedUTF8StrSlice,
                                   DiplomatSlice, DiplomatSliceMut, DiplomatStr16Slice, DiplomatStrSlice, DiplomatUtf8StrSlice};
            struct Foo<'a> {
                a: DiplomatSlice<'a, u8>,
                b: DiplomatSlice<'a, u16>,
                c: DiplomatUtf8StrSlice<'a>,
                d: DiplomatStrSlice<'a>,
                e: DiplomatStr16Slice<'a>,
                f: DiplomatSlice<'a, DiplomatByte>,
            }

            impl Foo {
                pub fn make(a: &'a [u8], b: &'a [u16], c: &'a str, d: &'a DiplomatStr, e: &'a DiplomatStr16, f: &'a [DiplomatByte]) -> Self {
                    Foo {
                        a, b, c, d, e, f,
                    }
                }
                pub fn make_runtime_types(a: DiplomatSlice<'a, u8>, b: DiplomatSlice<'a, u16>, c: DiplomatUtf8StrSlice<'a>, d: DiplomatStrSlice<'a>, e: DiplomatStr16Slice<'a>, f: DiplomatSlice<'a, DiplomatByte>) -> Self {
                    Foo {
                        a: a.into(),
                        b: b.into(),
                        c: c.into(),
                        d: d.into(),
                        e: e.into(),
                        f: f.into(),
                    }
                }
                pub fn boxes(a: Box<[u8]>, b: Box<[u16]>, c: Box<str>, d: Box<DiplomatStr>, e: Box<DiplomatStr16>, f: Box<[DiplomatByte]>) -> Self {
                    unimplemented!()
                }
                pub fn boxes_runtime_types(a: DiplomatOwnedSlice<u8>, b: DiplomatOwnedSlice<u16>, c: DiplomatOwnedUTF8StrSlice, d: DiplomatOwnedStrSlice, e: DiplomatOwnedStr16Slice, f: DiplomatOwnedSlice<DiplomatByte>) -> Self {
                    unimplemented!()
                }
                pub fn a(self) -> &[u8] {
                    self.a
                }
                pub fn b(self) -> &[u16] {
                    self.b
                }
                pub fn c(self) -> &str {
                    self.c
                }
                pub fn d(self) -> &DiplomatStr {
                    self.d
                }
                pub fn e(self) -> &DiplomatStr16 {
                    self.e
                }
                pub fn f(self) -> &[DiplomatByte] {
                    self.f
                }
            }
        }
    });
    let source = bridged.to_token_stream().to_string();
    insta::assert_snapshot!(rustfmt_code(&source));
}
782 
783     #[test]
method_taking_slice()784     fn method_taking_slice() {
785         insta::assert_snapshot!(rustfmt_code(
786             &gen_bridge(parse_quote! {
787                 mod ffi {
788                     struct Foo {}
789 
790                     impl Foo {
791                         pub fn from_slice(s: &[f64]) {
792                             unimplemented!()
793                         }
794                     }
795                 }
796             })
797             .to_token_stream()
798             .to_string()
799         ));
800     }
801 
802     #[test]
method_taking_mutable_slice()803     fn method_taking_mutable_slice() {
804         insta::assert_snapshot!(rustfmt_code(
805             &gen_bridge(parse_quote! {
806                 mod ffi {
807                     struct Foo {}
808 
809                     impl Foo {
810                         pub fn fill_slice(s: &mut [f64]) {
811                             unimplemented!()
812                         }
813                     }
814                 }
815             })
816             .to_token_stream()
817             .to_string()
818         ));
819     }
820 
821     #[test]
method_taking_owned_slice()822     fn method_taking_owned_slice() {
823         insta::assert_snapshot!(rustfmt_code(
824             &gen_bridge(parse_quote! {
825                 mod ffi {
826                     struct Foo {}
827 
828                     impl Foo {
829                         pub fn fill_slice(s: Box<[u16]>) {
830                             unimplemented!()
831                         }
832                     }
833                 }
834             })
835             .to_token_stream()
836             .to_string()
837         ));
838     }
839 
840     #[test]
method_taking_owned_str()841     fn method_taking_owned_str() {
842         insta::assert_snapshot!(rustfmt_code(
843             &gen_bridge(parse_quote! {
844                 mod ffi {
845                     struct Foo {}
846 
847                     impl Foo {
848                         pub fn something_with_str(s: Box<str>) {
849                             unimplemented!()
850                         }
851                     }
852                 }
853             })
854             .to_token_stream()
855             .to_string()
856         ));
857     }
858 
859     #[test]
mod_with_enum()860     fn mod_with_enum() {
861         insta::assert_snapshot!(rustfmt_code(
862             &gen_bridge(parse_quote! {
863                 mod ffi {
864                     enum Abc {
865                         A,
866                         B = 123,
867                     }
868 
869                     impl Abc {
870                         pub fn do_something(&self) {
871                             unimplemented!()
872                         }
873                     }
874                 }
875             })
876             .to_token_stream()
877             .to_string()
878         ));
879     }
880 
881     #[test]
mod_with_write_result()882     fn mod_with_write_result() {
883         insta::assert_snapshot!(rustfmt_code(
884             &gen_bridge(parse_quote! {
885                 mod ffi {
886                     struct Foo {}
887 
888                     impl Foo {
889                         pub fn to_string(&self, to: &mut DiplomatWrite) -> Result<(), ()> {
890                             unimplemented!()
891                         }
892                     }
893                 }
894             })
895             .to_token_stream()
896             .to_string()
897         ));
898     }
899 
900     #[test]
mod_with_rust_result()901     fn mod_with_rust_result() {
902         insta::assert_snapshot!(rustfmt_code(
903             &gen_bridge(parse_quote! {
904                 mod ffi {
905                     struct Foo {}
906 
907                     impl Foo {
908                         pub fn bar(&self) -> Result<(), ()> {
909                             unimplemented!()
910                         }
911                     }
912                 }
913             })
914             .to_token_stream()
915             .to_string()
916         ));
917     }
918 
919     #[test]
multilevel_borrows()920     fn multilevel_borrows() {
921         insta::assert_snapshot!(rustfmt_code(
922             &gen_bridge(parse_quote! {
923                 mod ffi {
924                     #[diplomat::opaque]
925                     struct Foo<'a>(&'a str);
926 
927                     #[diplomat::opaque]
928                     struct Bar<'b, 'a: 'b>(&'b Foo<'a>);
929 
930                     struct Baz<'x, 'y> {
931                         foo: &'y Foo<'x>,
932                     }
933 
934                     impl<'a> Foo<'a> {
935                         pub fn new(x: &'a str) -> Box<Foo<'a>> {
936                             unimplemented!()
937                         }
938 
939                         pub fn get_bar<'b>(&'b self) -> Box<Bar<'b, 'a>> {
940                             unimplemented!()
941                         }
942 
943                         pub fn get_baz<'b>(&'b self) -> Baz<'b, 'a> {
944                             Bax { foo: self }
945                         }
946                     }
947                 }
948             })
949             .to_token_stream()
950             .to_string()
951         ));
952     }
953 
954     #[test]
self_params()955     fn self_params() {
956         insta::assert_snapshot!(rustfmt_code(
957             &gen_bridge(parse_quote! {
958                 mod ffi {
959                     #[diplomat::opaque]
960                     struct RefList<'a> {
961                         data: &'a i32,
962                         next: Option<Box<Self>>,
963                     }
964 
965                     impl<'b> RefList<'b> {
966                         pub fn extend(&mut self, other: &Self) -> Self {
967                             unimplemented!()
968                         }
969                     }
970                 }
971             })
972             .to_token_stream()
973             .to_string()
974         ));
975     }
976 
977     #[test]
cfged_method()978     fn cfged_method() {
979         insta::assert_snapshot!(rustfmt_code(
980             &gen_bridge(parse_quote! {
981                 mod ffi {
982                     struct Foo {}
983 
984                     impl Foo {
985                         #[cfg(feature = "foo")]
986                         pub fn bar(s: u8) {
987                             unimplemented!()
988                         }
989                     }
990                 }
991             })
992             .to_token_stream()
993             .to_string()
994         ));
995 
996         insta::assert_snapshot!(rustfmt_code(
997             &gen_bridge(parse_quote! {
998                 mod ffi {
999                     struct Foo {}
1000 
1001                     #[cfg(feature = "bar")]
1002                     impl Foo {
1003                         #[cfg(feature = "foo")]
1004                         pub fn bar(s: u8) {
1005                             unimplemented!()
1006                         }
1007                     }
1008                 }
1009             })
1010             .to_token_stream()
1011             .to_string()
1012         ));
1013     }
1014 
1015     #[test]
cfgd_struct()1016     fn cfgd_struct() {
1017         insta::assert_snapshot!(rustfmt_code(
1018             &gen_bridge(parse_quote! {
1019                 mod ffi {
1020                     #[diplomat::opaque]
1021                     #[cfg(feature = "foo")]
1022                     struct Foo {}
1023                     #[cfg(feature = "foo")]
1024                     impl Foo {
1025                         pub fn bar(s: u8) {
1026                             unimplemented!()
1027                         }
1028                     }
1029                 }
1030             })
1031             .to_token_stream()
1032             .to_string()
1033         ));
1034     }
1035 
1036     #[test]
callback_arguments()1037     fn callback_arguments() {
1038         insta::assert_snapshot!(rustfmt_code(
1039             &gen_bridge(parse_quote! {
1040                 mod ffi {
1041                     pub struct Wrapper {
1042                         cant_be_empty: bool,
1043                     }
1044                     pub struct TestingStruct {
1045                         x: i32,
1046                         y: i32,
1047                     }
1048                     impl Wrapper {
1049                         pub fn test_multi_arg_callback(f: impl Fn(i32) -> i32, x: i32) -> i32 {
1050                             f(10 + x)
1051                         }
1052                         pub fn test_multiarg_void_callback(f: impl Fn(i32, &str)) {
1053                             f(-10, "hello it's a string\0");
1054                         }
1055                         pub fn test_mod_array(g: impl Fn(&[u8])) {
1056                             let bytes: Vec<u8> = vec![0x11, 0x22];
1057                             g(bytes.as_slice().into());
1058                         }
1059                         pub fn test_no_args(h: impl Fn()) -> i32 {
1060                             h();
1061                             -5
1062                         }
1063                         pub fn test_cb_with_struct(f: impl Fn(TestingStruct) -> i32) -> i32 {
1064                             let arg = TestingStruct {
1065                                 x: 1,
1066                                 y: 5,
1067                             };
1068                             f(arg)
1069                         }
1070                         pub fn test_multiple_cb_args(f: impl Fn() -> i32, g: impl Fn(i32) -> i32) -> i32 {
1071                             f() + g(5)
1072                         }
1073                     }
1074                 }
1075             })
1076             .to_token_stream()
1077             .to_string()
1078         ));
1079     }
1080 
1081     #[test]
traits()1082     fn traits() {
1083         insta::assert_snapshot!(rustfmt_code(
1084             &gen_bridge(parse_quote! {
1085                 mod ffi {
1086                     pub struct TestingStruct {
1087                         x: i32,
1088                         y: i32,
1089                     }
1090 
1091                     pub trait TesterTrait: std::marker::Send {
1092                         fn test_trait_fn(&self, x: i32) -> i32;
1093                         fn test_void_trait_fn(&self);
1094                         fn test_struct_trait_fn(&self, s: TestingStruct) -> i32;
1095                         fn test_slice_trait_fn(&self, s: &[u8]) -> i32;
1096                     }
1097 
1098                     pub struct Wrapper {
1099                         cant_be_empty: bool,
1100                     }
1101 
1102                     impl Wrapper {
1103                         pub fn test_with_trait(t: impl TesterTrait, x: i32) -> i32 {
1104                             t.test_void_trait_fn();
1105                             t.test_trait_fn(x)
1106                         }
1107 
1108                         pub fn test_trait_with_struct(t: impl TesterTrait) -> i32 {
1109                             let arg = TestingStruct {
1110                                 x: 1,
1111                                 y: 5,
1112                             };
1113                             t.test_struct_trait_fn(arg)
1114                         }
1115                     }
1116 
1117                 }
1118             })
1119             .to_token_stream()
1120             .to_string()
1121         ));
1122     }
1123 
1124     #[test]
both_kinds_of_option()1125     fn both_kinds_of_option() {
1126         insta::assert_snapshot!(rustfmt_code(
1127             &gen_bridge(parse_quote! {
1128                 mod ffi {
1129                     use diplomat_runtime::DiplomatOption;
1130                     #[diplomat::opaque]
1131                     struct Foo {}
1132                     struct CustomStruct {
1133                         num: u8,
1134                         b: bool,
1135                         diplo_option: DiplomatOption<u8>,
1136                     }
1137                     impl Foo {
1138                         pub fn diplo_option_u8(x: DiplomatOption<u8>) -> DiplomatOption<u8> {
1139                             x
1140                         }
1141                         pub fn diplo_option_ref(x: DiplomatOption<&Foo>) -> DiplomatOption<&Foo> {
1142                             x
1143                         }
1144                         pub fn diplo_option_box() -> DiplomatOption<Box<Foo>> {
1145                             x
1146                         }
1147                         pub fn diplo_option_struct(x: DiplomatOption<CustomStruct>) -> DiplomatOption<CustomStruct> {
1148                             x
1149                         }
1150                         pub fn option_u8(x: Option<u8>) -> Option<u8> {
1151                             x
1152                         }
1153                         pub fn option_ref(x: Option<&Foo>) -> Option<&Foo> {
1154                             x
1155                         }
1156                         pub fn option_box() -> Option<Box<Foo>> {
1157                             x
1158                         }
1159                         pub fn option_struct(x: Option<CustomStruct>) -> Option<CustomStruct> {
1160                             x
1161                         }
1162                     }
1163                 }
1164             })
1165             .to_token_stream()
1166             .to_string()
1167         ));
1168     }
1169 }
1170