//! The futures-rs `select!` macro implementation.
2
3 use proc_macro::TokenStream;
4 use proc_macro2::Span;
5 use quote::{format_ident, quote};
6 use syn::{parse_quote, Expr, Ident, Pat, Token};
7 use syn::parse::{Parse, ParseStream};
8
9 mod kw {
10 syn::custom_keyword!(complete);
11 }
12
13 struct Select {
14 // span of `complete`, then expression after `=> ...`
15 complete: Option<Expr>,
16 default: Option<Expr>,
17 normal_fut_exprs: Vec<Expr>,
18 normal_fut_handlers: Vec<(Pat, Expr)>,
19 }
20
21 #[allow(clippy::large_enum_variant)]
22 enum CaseKind {
23 Complete,
24 Default,
25 Normal(Pat, Expr),
26 }
27
28 impl Parse for Select {
parse(input: ParseStream<'_>) -> syn::Result<Self>29 fn parse(input: ParseStream<'_>) -> syn::Result<Self> {
30 let mut select = Self {
31 complete: None,
32 default: None,
33 normal_fut_exprs: vec![],
34 normal_fut_handlers: vec![],
35 };
36
37 while !input.is_empty() {
38 let case_kind = if input.peek(kw::complete) {
39 // `complete`
40 if select.complete.is_some() {
41 return Err(input.error("multiple `complete` cases found, only one allowed"));
42 }
43 input.parse::<kw::complete>()?;
44 CaseKind::Complete
45 } else if input.peek(Token![default]) {
46 // `default`
47 if select.default.is_some() {
48 return Err(input.error("multiple `default` cases found, only one allowed"));
49 }
50 input.parse::<Ident>()?;
51 CaseKind::Default
52 } else {
53 // `<pat> = <expr>`
54 let pat = input.parse()?;
55 input.parse::<Token![=]>()?;
56 let expr = input.parse()?;
57 CaseKind::Normal(pat, expr)
58 };
59
60 // `=> <expr>`
61 input.parse::<Token![=>]>()?;
62 let expr = input.parse::<Expr>()?;
63
64 // Commas after the expression are only optional if it's a `Block`
65 // or it is the last branch in the `match`.
66 let is_block = match expr { Expr::Block(_) => true, _ => false };
67 if is_block || input.is_empty() {
68 input.parse::<Option<Token![,]>>()?;
69 } else {
70 input.parse::<Token![,]>()?;
71 }
72
73 match case_kind {
74 CaseKind::Complete => select.complete = Some(expr),
75 CaseKind::Default => select.default = Some(expr),
76 CaseKind::Normal(pat, fut_expr) => {
77 select.normal_fut_exprs.push(fut_expr);
78 select.normal_fut_handlers.push((pat, expr));
79 },
80 }
81 }
82
83 Ok(select)
84 }
85 }
86
87 // Enum over all the cases in which the `select!` waiting has completed and the result
88 // can be processed.
89 //
90 // `enum __PrivResult<_1, _2, ...> { _1(_1), _2(_2), ..., Complete }`
declare_result_enum( result_ident: Ident, variants: usize, complete: bool, span: Span ) -> (Vec<Ident>, syn::ItemEnum)91 fn declare_result_enum(
92 result_ident: Ident,
93 variants: usize,
94 complete: bool,
95 span: Span
96 ) -> (Vec<Ident>, syn::ItemEnum) {
97 // "_0", "_1", "_2"
98 let variant_names: Vec<Ident> =
99 (0..variants)
100 .map(|num| format_ident!("_{}", num, span = span))
101 .collect();
102
103 let type_parameters = &variant_names;
104 let variants = &variant_names;
105
106 let complete_variant = if complete {
107 Some(quote!(Complete))
108 } else {
109 None
110 };
111
112 let enum_item = parse_quote! {
113 enum #result_ident<#(#type_parameters,)*> {
114 #(
115 #variants(#type_parameters),
116 )*
117 #complete_variant
118 }
119 };
120
121 (variant_names, enum_item)
122 }
123
124 /// The `select!` macro.
select(input: TokenStream) -> TokenStream125 pub(crate) fn select(input: TokenStream) -> TokenStream {
126 select_inner(input, true)
127 }
128
129 /// The `select_biased!` macro.
select_biased(input: TokenStream) -> TokenStream130 pub(crate) fn select_biased(input: TokenStream) -> TokenStream {
131 select_inner(input, false)
132 }
133
select_inner(input: TokenStream, random: bool) -> TokenStream134 fn select_inner(input: TokenStream, random: bool) -> TokenStream {
135 let parsed = syn::parse_macro_input!(input as Select);
136
137 // should be def_site, but that's unstable
138 let span = Span::call_site();
139
140 let enum_ident = Ident::new("__PrivResult", span);
141
142 let (variant_names, enum_item) = declare_result_enum(
143 enum_ident.clone(),
144 parsed.normal_fut_exprs.len(),
145 parsed.complete.is_some(),
146 span,
147 );
148
149 // bind non-`Ident` future exprs w/ `let`
150 let mut future_let_bindings = Vec::with_capacity(parsed.normal_fut_exprs.len());
151 let bound_future_names: Vec<_> = parsed.normal_fut_exprs.into_iter()
152 .zip(variant_names.iter())
153 .map(|(expr, variant_name)| {
154 match expr {
155 syn::Expr::Path(path) => {
156 // Don't bind futures that are already a path.
157 // This prevents creating redundant stack space
158 // for them.
159 // Passing Futures by path requires those Futures to implement Unpin.
160 // We check for this condition here in order to be able to
161 // safely use Pin::new_unchecked(&mut #path) later on.
162 future_let_bindings.push(quote! {
163 __futures_crate::async_await::assert_fused_future(&#path);
164 __futures_crate::async_await::assert_unpin(&#path);
165 });
166 path
167 },
168 _ => {
169 // Bind and pin the resulting Future on the stack. This is
170 // necessary to support direct select! calls on !Unpin
171 // Futures. The Future is not explicitly pinned here with
172 // a Pin call, but assumed as pinned. The actual Pin is
173 // created inside the poll() function below to defer the
174 // creation of the temporary pointer, which would otherwise
175 // increase the size of the generated Future.
176 // Safety: This is safe since the lifetime of the Future
177 // is totally constraint to the lifetime of the select!
178 // expression, and the Future can't get moved inside it
179 // (it is shadowed).
180 future_let_bindings.push(quote! {
181 let mut #variant_name = #expr;
182 });
183 parse_quote! { #variant_name }
184 }
185 }
186 })
187 .collect();
188
189 // For each future, make an `&mut dyn FnMut(&mut Context<'_>) -> Option<Poll<__PrivResult<...>>`
190 // to use for polling that individual future. These will then be put in an array.
191 let poll_functions = bound_future_names.iter().zip(variant_names.iter())
192 .map(|(bound_future_name, variant_name)| {
193 // Below we lazily create the Pin on the Future below.
194 // This is done in order to avoid allocating memory in the generator
195 // for the Pin variable.
196 // Safety: This is safe because one of the following condition applies:
197 // 1. The Future is passed by the caller by name, and we assert that
198 // it implements Unpin.
199 // 2. The Future is created in scope of the select! function and will
200 // not be moved for the duration of it. It is thereby stack-pinned
201 quote! {
202 let mut #variant_name = |__cx: &mut __futures_crate::task::Context<'_>| {
203 let mut #bound_future_name = unsafe {
204 __futures_crate::Pin::new_unchecked(&mut #bound_future_name)
205 };
206 if __futures_crate::future::FusedFuture::is_terminated(&#bound_future_name) {
207 __futures_crate::None
208 } else {
209 __futures_crate::Some(__futures_crate::future::FutureExt::poll_unpin(
210 &mut #bound_future_name,
211 __cx,
212 ).map(#enum_ident::#variant_name))
213 }
214 };
215 let #variant_name: &mut dyn FnMut(
216 &mut __futures_crate::task::Context<'_>
217 ) -> __futures_crate::Option<__futures_crate::task::Poll<_>> = &mut #variant_name;
218 }
219 });
220
221 let none_polled = if parsed.complete.is_some() {
222 quote! {
223 __futures_crate::task::Poll::Ready(#enum_ident::Complete)
224 }
225 } else {
226 quote! {
227 panic!("all futures in select! were completed,\
228 but no `complete =>` handler was provided")
229 }
230 };
231
232 let branches = parsed.normal_fut_handlers.into_iter()
233 .zip(variant_names.iter())
234 .map(|((pat, expr), variant_name)| {
235 quote! {
236 #enum_ident::#variant_name(#pat) => { #expr },
237 }
238 });
239 let branches = quote! { #( #branches )* };
240
241 let complete_branch = parsed.complete.map(|complete_expr| {
242 quote! {
243 #enum_ident::Complete => { #complete_expr },
244 }
245 });
246
247 let branches = quote! {
248 #branches
249 #complete_branch
250 };
251
252 let await_select_fut = if parsed.default.is_some() {
253 // For select! with default this returns the Poll result
254 quote! {
255 __poll_fn(&mut __futures_crate::task::Context::from_waker(
256 __futures_crate::task::noop_waker_ref()
257 ))
258 }
259 } else {
260 quote! {
261 __futures_crate::future::poll_fn(__poll_fn).await
262 }
263 };
264
265 let execute_result_expr = if let Some(default_expr) = &parsed.default {
266 // For select! with default __select_result is a Poll, otherwise not
267 quote! {
268 match __select_result {
269 __futures_crate::task::Poll::Ready(result) => match result {
270 #branches
271 },
272 _ => #default_expr
273 }
274 }
275 } else {
276 quote! {
277 match __select_result {
278 #branches
279 }
280 }
281 };
282
283 let shuffle = if random {
284 quote! {
285 __futures_crate::async_await::shuffle(&mut __select_arr);
286 }
287 } else {
288 quote!()
289 };
290
291 TokenStream::from(quote! { {
292 #enum_item
293
294 let __select_result = {
295 #( #future_let_bindings )*
296
297 let mut __poll_fn = |__cx: &mut __futures_crate::task::Context<'_>| {
298 let mut __any_polled = false;
299
300 #( #poll_functions )*
301
302 let mut __select_arr = [#( #variant_names ),*];
303 #shuffle
304 for poller in &mut __select_arr {
305 let poller: &mut &mut dyn FnMut(
306 &mut __futures_crate::task::Context<'_>
307 ) -> __futures_crate::Option<__futures_crate::task::Poll<_>> = poller;
308 match poller(__cx) {
309 __futures_crate::Some(x @ __futures_crate::task::Poll::Ready(_)) =>
310 return x,
311 __futures_crate::Some(__futures_crate::task::Poll::Pending) => {
312 __any_polled = true;
313 }
314 __futures_crate::None => {}
315 }
316 }
317
318 if !__any_polled {
319 #none_polled
320 } else {
321 __futures_crate::task::Poll::Pending
322 }
323 };
324
325 #await_select_fut
326 };
327
328 #execute_result_expr
329 } })
330 }
331