1 //! Core dependency injection macros
2
3 extern crate proc_macro;
4 use proc_macro::TokenStream;
5 use quote::{format_ident, quote};
6 use syn::parse::{Parse, ParseStream, Result};
7 use syn::punctuated::Punctuated;
8 use syn::{
9 braced, parse, parse_macro_input, DeriveInput, Fields, FnArg, Ident, ItemFn, ItemStruct, Path,
10 Token, Type,
11 };
12
13 /// Defines a provider function, with generated helper that implicitly fetches argument instances from the registry
14 #[proc_macro_attribute]
provides(_attr: TokenStream, item: TokenStream) -> TokenStream15 pub fn provides(_attr: TokenStream, item: TokenStream) -> TokenStream {
16 let function: ItemFn = parse(item).expect("can only be applied to functions");
17
18 // Create the info needed to refer to the function & the injected version we generate
19 let ident = function.sig.ident.clone();
20 let injected_ident = format_ident!("__gddi_{}_injected", ident);
21
22 // Create the info needed to generate the call to the original function
23 let inputs = function.sig.inputs.iter().map(|arg| {
24 if let FnArg::Typed(t) = arg {
25 return t.ty.clone();
26 }
27 panic!("can't be applied to struct methods");
28 });
29 let local_var_idents = (0..inputs.len()).map(|i| format_ident!("__input{}", i));
30 let local_var_idents_for_call = local_var_idents.clone();
31
32 let emitted_code = quote! {
33 // Injecting wrapper
34 fn #injected_ident(registry: std::sync::Arc<gddi::Registry>) -> std::pin::Pin<gddi::ProviderFutureBox> {
35 Box::pin(async move {
36 // Create a local variable for each argument, to ensure they get generated in a
37 // deterministic order (compiler complains otherwise)
38 #(let #local_var_idents = registry.get::<#inputs>().await;)*
39
40 // Actually call the original function
41 Box::new(#ident(#(#local_var_idents_for_call),*).await) as Box<dyn std::any::Any>
42 })
43 }
44 #function
45 };
46 emitted_code.into()
47 }
48
49 struct ModuleDef {
50 name: Ident,
51 providers: Punctuated<ProviderDef, Token![,]>,
52 submodules: Punctuated<Path, Token![,]>,
53 }
54
55 enum ModuleEntry {
56 Providers(Punctuated<ProviderDef, Token![,]>),
57 Submodules(Punctuated<Path, Token![,]>),
58 }
59
60 struct ProviderDef {
61 ty: Type,
62 ident: Ident,
63 parts: bool,
64 }
65
66 impl Parse for ModuleDef {
parse(input: ParseStream) -> Result<Self>67 fn parse(input: ParseStream) -> Result<Self> {
68 // first thing is the module name followed by a comma
69 let name = input.parse()?;
70 input.parse::<Token![,]>()?;
71 // Then comes submodules or provider sections, in any order
72 let entries: Punctuated<ModuleEntry, Token![,]> = Punctuated::parse_terminated(input)?;
73 let mut providers = Punctuated::new();
74 let mut submodules = Punctuated::new();
75 for entry in entries.into_iter() {
76 match entry {
77 ModuleEntry::Providers(value) => {
78 if !providers.is_empty() {
79 panic!("providers specified more than once");
80 }
81 providers = value;
82 }
83 ModuleEntry::Submodules(value) => {
84 if !submodules.is_empty() {
85 panic!("submodules specified more than once");
86 }
87 submodules = value;
88 }
89 }
90 }
91 Ok(ModuleDef { name, providers, submodules })
92 }
93 }
94
95 impl Parse for ProviderDef {
parse(input: ParseStream) -> Result<Self>96 fn parse(input: ParseStream) -> Result<Self> {
97 let parts = input.peek3(Token![=>]);
98 if parts {
99 match input.parse::<Ident>()?.to_string().as_str() {
100 "parts" => {}
101 keyword => panic!("expected 'parts', got '{}'", keyword),
102 }
103 }
104
105 // A provider definition follows this format: <Type> -> <function name>
106 let ty = input.parse()?;
107 input.parse::<Token![=>]>()?;
108 let ident = input.parse()?;
109 Ok(ProviderDef { ty, ident, parts })
110 }
111 }
112
113 impl Parse for ModuleEntry {
parse(input: ParseStream) -> Result<Self>114 fn parse(input: ParseStream) -> Result<Self> {
115 match input.parse::<Ident>()?.to_string().as_str() {
116 "providers" => {
117 let entries;
118 braced!(entries in input);
119 Ok(ModuleEntry::Providers(entries.parse_terminated(ProviderDef::parse)?))
120 }
121 "submodules" => {
122 let entries;
123 braced!(entries in input);
124 Ok(ModuleEntry::Submodules(entries.parse_terminated(Path::parse)?))
125 }
126 keyword => {
127 panic!("unexpected keyword: {}", keyword);
128 }
129 }
130 }
131 }
132
133 /// Emits a module function that registers submodules & providers with the registry
134 #[proc_macro]
module(item: TokenStream) -> TokenStream135 pub fn module(item: TokenStream) -> TokenStream {
136 let module = parse_macro_input!(item as ModuleDef);
137 let init_ident = module.name.clone();
138 let providers = module.providers.iter();
139 let types = providers.clone().map(|p| p.ty.clone());
140 let provider_idents =
141 providers.clone().map(|p| format_ident!("__gddi_{}_injected", p.ident.clone()));
142 let parting_functions = providers.filter_map(|p| match &p.ty {
143 Type::Path(ty) if p.parts => Some(format_ident!(
144 "__gddi_part_out_{}",
145 ty.path.get_ident().unwrap().to_string().to_lowercase()
146 )),
147 _ => None,
148 });
149 let submodule_idents = module.submodules.iter();
150 let emitted_code = quote! {
151 #[doc(hidden)]
152 #[allow(missing_docs)]
153 pub fn #init_ident(builder: gddi::RegistryBuilder) -> gddi::RegistryBuilder {
154 // Register all providers on this module
155 let ret = builder#(.register_provider::<#types>(Box::new(#provider_idents)))*
156 // Register all submodules on this module
157 #(.register_module(#submodule_idents))*;
158
159 #(let ret = #parting_functions(ret);)*
160
161 ret
162 }
163 };
164 emitted_code.into()
165 }
166
167 /// Emits a default implementation for Stoppable that does nothing;
168 #[proc_macro_derive(Stoppable)]
derive_nop_stop(item: TokenStream) -> TokenStream169 pub fn derive_nop_stop(item: TokenStream) -> TokenStream {
170 let input = parse_macro_input!(item as DeriveInput);
171 let ident = input.ident;
172 let emitted_code = quote! {
173 impl gddi::Stoppable for #ident {}
174 };
175 emitted_code.into()
176 }
177
178 /// Generates the code necessary to split up a type into its components
179 #[proc_macro_attribute]
part_out(_attr: TokenStream, item: TokenStream) -> TokenStream180 pub fn part_out(_attr: TokenStream, item: TokenStream) -> TokenStream {
181 let struct_: ItemStruct = parse(item).expect("can only be applied to struct definitions");
182 let struct_ident = struct_.ident.clone();
183 let fields = match struct_.fields.clone() {
184 Fields::Named(f) => f,
185 _ => panic!("can only be applied to structs with named fields"),
186 }
187 .named;
188
189 let field_names = fields.iter().map(|f| f.ident.clone().expect("field without a name"));
190 let field_types = fields.iter().map(|f| f.ty.clone());
191
192 let fn_ident = format_ident!("__gddi_part_out_{}", struct_ident.to_string().to_lowercase());
193
194 let emitted_code = quote! {
195 #struct_
196
197 fn #fn_ident(builder: gddi::RegistryBuilder) -> gddi::RegistryBuilder {
198 builder#(.register_provider::<#field_types>(Box::new(
199 |registry: std::sync::Arc<gddi::Registry>| -> std::pin::Pin<gddi::ProviderFutureBox> {
200 Box::pin(async move {
201 Box::new(async move {
202 registry.get::<#struct_ident>().await.#field_names
203 }.await) as Box<dyn std::any::Any>
204 })
205 })))*
206 }
207 };
208 emitted_code.into()
209 }
210