• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 //! Core dependency injection macros
2 
3 extern crate proc_macro;
4 use proc_macro::TokenStream;
5 use quote::{format_ident, quote};
6 use syn::parse::{Parse, ParseStream, Result};
7 use syn::punctuated::Punctuated;
8 use syn::{
9     braced, parse, parse_macro_input, DeriveInput, Fields, FnArg, Ident, ItemFn, ItemStruct, Path,
10     Token, Type,
11 };
12 
13 /// Defines a provider function, with generated helper that implicitly fetches argument instances from the registry
14 #[proc_macro_attribute]
provides(_attr: TokenStream, item: TokenStream) -> TokenStream15 pub fn provides(_attr: TokenStream, item: TokenStream) -> TokenStream {
16     let function: ItemFn = parse(item).expect("can only be applied to functions");
17 
18     // Create the info needed to refer to the function & the injected version we generate
19     let ident = function.sig.ident.clone();
20     let injected_ident = format_ident!("__gddi_{}_injected", ident);
21 
22     // Create the info needed to generate the call to the original function
23     let inputs = function.sig.inputs.iter().map(|arg| {
24         if let FnArg::Typed(t) = arg {
25             return t.ty.clone();
26         }
27         panic!("can't be applied to struct methods");
28     });
29     let local_var_idents = (0..inputs.len()).map(|i| format_ident!("__input{}", i));
30     let local_var_idents_for_call = local_var_idents.clone();
31 
32     let emitted_code = quote! {
33         // Injecting wrapper
34         fn #injected_ident(registry: std::sync::Arc<gddi::Registry>) -> std::pin::Pin<gddi::ProviderFutureBox> {
35             Box::pin(async move {
36                 // Create a local variable for each argument, to ensure they get generated in a
37                 // deterministic order (compiler complains otherwise)
38                 #(let #local_var_idents = registry.get::<#inputs>().await;)*
39 
40                 // Actually call the original function
41                 Box::new(#ident(#(#local_var_idents_for_call),*).await) as Box<dyn std::any::Any>
42             })
43         }
44         #function
45     };
46     emitted_code.into()
47 }
48 
49 struct ModuleDef {
50     name: Ident,
51     providers: Punctuated<ProviderDef, Token![,]>,
52     submodules: Punctuated<Path, Token![,]>,
53 }
54 
55 enum ModuleEntry {
56     Providers(Punctuated<ProviderDef, Token![,]>),
57     Submodules(Punctuated<Path, Token![,]>),
58 }
59 
60 struct ProviderDef {
61     ty: Type,
62     ident: Ident,
63     parts: bool,
64 }
65 
66 impl Parse for ModuleDef {
parse(input: ParseStream) -> Result<Self>67     fn parse(input: ParseStream) -> Result<Self> {
68         // first thing is the module name followed by a comma
69         let name = input.parse()?;
70         input.parse::<Token![,]>()?;
71         // Then comes submodules or provider sections, in any order
72         let entries: Punctuated<ModuleEntry, Token![,]> = Punctuated::parse_terminated(input)?;
73         let mut providers = Punctuated::new();
74         let mut submodules = Punctuated::new();
75         for entry in entries.into_iter() {
76             match entry {
77                 ModuleEntry::Providers(value) => {
78                     if !providers.is_empty() {
79                         panic!("providers specified more than once");
80                     }
81                     providers = value;
82                 }
83                 ModuleEntry::Submodules(value) => {
84                     if !submodules.is_empty() {
85                         panic!("submodules specified more than once");
86                     }
87                     submodules = value;
88                 }
89             }
90         }
91         Ok(ModuleDef { name, providers, submodules })
92     }
93 }
94 
95 impl Parse for ProviderDef {
parse(input: ParseStream) -> Result<Self>96     fn parse(input: ParseStream) -> Result<Self> {
97         let parts = input.peek3(Token![=>]);
98         if parts {
99             match input.parse::<Ident>()?.to_string().as_str() {
100                 "parts" => {}
101                 keyword => panic!("expected 'parts', got '{}'", keyword),
102             }
103         }
104 
105         // A provider definition follows this format: <Type> -> <function name>
106         let ty = input.parse()?;
107         input.parse::<Token![=>]>()?;
108         let ident = input.parse()?;
109         Ok(ProviderDef { ty, ident, parts })
110     }
111 }
112 
113 impl Parse for ModuleEntry {
parse(input: ParseStream) -> Result<Self>114     fn parse(input: ParseStream) -> Result<Self> {
115         match input.parse::<Ident>()?.to_string().as_str() {
116             "providers" => {
117                 let entries;
118                 braced!(entries in input);
119                 Ok(ModuleEntry::Providers(entries.parse_terminated(ProviderDef::parse)?))
120             }
121             "submodules" => {
122                 let entries;
123                 braced!(entries in input);
124                 Ok(ModuleEntry::Submodules(entries.parse_terminated(Path::parse)?))
125             }
126             keyword => {
127                 panic!("unexpected keyword: {}", keyword);
128             }
129         }
130     }
131 }
132 
133 /// Emits a module function that registers submodules & providers with the registry
134 #[proc_macro]
module(item: TokenStream) -> TokenStream135 pub fn module(item: TokenStream) -> TokenStream {
136     let module = parse_macro_input!(item as ModuleDef);
137     let init_ident = module.name.clone();
138     let providers = module.providers.iter();
139     let types = providers.clone().map(|p| p.ty.clone());
140     let provider_idents =
141         providers.clone().map(|p| format_ident!("__gddi_{}_injected", p.ident.clone()));
142     let parting_functions = providers.filter_map(|p| match &p.ty {
143         Type::Path(ty) if p.parts => Some(format_ident!(
144             "__gddi_part_out_{}",
145             ty.path.get_ident().unwrap().to_string().to_lowercase()
146         )),
147         _ => None,
148     });
149     let submodule_idents = module.submodules.iter();
150     let emitted_code = quote! {
151         #[doc(hidden)]
152         #[allow(missing_docs)]
153         pub fn #init_ident(builder: gddi::RegistryBuilder) -> gddi::RegistryBuilder {
154             // Register all providers on this module
155             let ret = builder#(.register_provider::<#types>(Box::new(#provider_idents)))*
156             // Register all submodules on this module
157             #(.register_module(#submodule_idents))*;
158 
159             #(let ret = #parting_functions(ret);)*
160 
161             ret
162         }
163     };
164     emitted_code.into()
165 }
166 
167 /// Emits a default implementation for Stoppable that does nothing;
168 #[proc_macro_derive(Stoppable)]
derive_nop_stop(item: TokenStream) -> TokenStream169 pub fn derive_nop_stop(item: TokenStream) -> TokenStream {
170     let input = parse_macro_input!(item as DeriveInput);
171     let ident = input.ident;
172     let emitted_code = quote! {
173         impl gddi::Stoppable for #ident {}
174     };
175     emitted_code.into()
176 }
177 
178 /// Generates the code necessary to split up a type into its components
179 #[proc_macro_attribute]
part_out(_attr: TokenStream, item: TokenStream) -> TokenStream180 pub fn part_out(_attr: TokenStream, item: TokenStream) -> TokenStream {
181     let struct_: ItemStruct = parse(item).expect("can only be applied to struct definitions");
182     let struct_ident = struct_.ident.clone();
183     let fields = match struct_.fields.clone() {
184         Fields::Named(f) => f,
185         _ => panic!("can only be applied to structs with named fields"),
186     }
187     .named;
188 
189     let field_names = fields.iter().map(|f| f.ident.clone().expect("field without a name"));
190     let field_types = fields.iter().map(|f| f.ty.clone());
191 
192     let fn_ident = format_ident!("__gddi_part_out_{}", struct_ident.to_string().to_lowercase());
193 
194     let emitted_code = quote! {
195         #struct_
196 
197         fn #fn_ident(builder: gddi::RegistryBuilder) -> gddi::RegistryBuilder {
198             builder#(.register_provider::<#field_types>(Box::new(
199                 |registry: std::sync::Arc<gddi::Registry>| -> std::pin::Pin<gddi::ProviderFutureBox> {
200                     Box::pin(async move {
201                         Box::new(async move {
202                             registry.get::<#struct_ident>().await.#field_names
203                         }.await) as Box<dyn std::any::Any>
204                     })
205                 })))*
206         }
207     };
208     emitted_code.into()
209 }
210