• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 //! This is the implementation of the pass which transforms generators into state machines.
2 //!
3 //! MIR generation for generators creates a function which has a self argument which
4 //! passes by value. This argument is effectively a generator type which only contains upvars and
5 //! is only used for this argument inside the MIR for the generator.
6 //! It is passed by value to enable upvars to be moved out of it. Drop elaboration runs on that
7 //! MIR before this pass and creates drop flags for MIR locals.
8 //! It will also drop the generator argument (which only consists of upvars) if any of the upvars
9 //! are moved out of. This pass elaborates the drops of upvars / generator argument in the case
10 //! that none of the upvars were moved out of. This is because we cannot have any drops of this
11 //! generator in the MIR, since it is used to create the drop glue for the generator. We'd get
12 //! infinite recursion otherwise.
13 //!
14 //! This pass creates the implementation for either the `Generator::resume` or `Future::poll`
15 //! function and the drop shim for the generator based on the MIR input.
16 //! It converts the generator argument from Self to &mut Self adding derefs in the MIR as needed.
17 //! It computes the final layout of the generator struct which looks like this:
18 //!     First upvars are stored
19 //!     It is followed by the generator state field.
20 //!     Then finally the MIR locals which are live across a suspension point are stored.
21 //!     ```ignore (illustrative)
22 //!     struct Generator {
23 //!         upvars...,
24 //!         state: u32,
25 //!         mir_locals...,
26 //!     }
27 //!     ```
28 //! This pass computes the meaning of the state field and the MIR locals which are live
29 //! across a suspension point. There are however three hardcoded generator states:
30 //!     0 - Generator has not been resumed yet
31 //!     1 - Generator has returned / is completed
32 //!     2 - Generator has been poisoned
33 //!
34 //! It also rewrites `return x` and `yield y` as setting a new generator state and returning
35 //! `GeneratorState::Complete(x)` and `GeneratorState::Yielded(y)`,
36 //! or `Poll::Ready(x)` and `Poll::Pending` respectively.
37 //! MIR locals which are live across a suspension point are moved to the generator struct
38 //! with references to them being updated with references to the generator struct.
39 //!
40 //! The pass creates two functions which have a switch on the generator state giving
41 //! the action to take.
42 //!
43 //! One of them is the implementation of `Generator::resume` / `Future::poll`.
44 //! For generators with state 0 (unresumed) it starts the execution of the generator.
45 //! For generators with state 1 (returned) and state 2 (poisoned) it panics.
46 //! Otherwise it continues the execution from the last suspension point.
47 //!
48 //! The other function is the drop glue for the generator.
49 //! For generators with state 0 (unresumed) it drops the upvars of the generator.
50 //! For generators with state 1 (returned) and state 2 (poisoned) it does nothing.
51 //! Otherwise it drops all the values in scope at the last suspension point.
52 
53 use crate::deref_separator::deref_finder;
54 use crate::errors;
55 use crate::simplify;
56 use crate::MirPass;
57 use rustc_data_structures::fx::{FxHashMap, FxHashSet};
58 use rustc_errors::pluralize;
59 use rustc_hir as hir;
60 use rustc_hir::lang_items::LangItem;
61 use rustc_hir::GeneratorKind;
62 use rustc_index::bit_set::{BitMatrix, BitSet, GrowableBitSet};
63 use rustc_index::{Idx, IndexVec};
64 use rustc_middle::mir::dump_mir;
65 use rustc_middle::mir::visit::{MutVisitor, PlaceContext, Visitor};
66 use rustc_middle::mir::*;
67 use rustc_middle::ty::{self, AdtDef, Ty, TyCtxt};
68 use rustc_middle::ty::{GeneratorSubsts, SubstsRef};
69 use rustc_mir_dataflow::impls::{
70     MaybeBorrowedLocals, MaybeLiveLocals, MaybeRequiresStorage, MaybeStorageLive,
71 };
72 use rustc_mir_dataflow::storage::always_storage_live_locals;
73 use rustc_mir_dataflow::{self, Analysis};
74 use rustc_span::def_id::{DefId, LocalDefId};
75 use rustc_span::symbol::sym;
76 use rustc_span::Span;
77 use rustc_target::abi::{FieldIdx, VariantIdx};
78 use rustc_target::spec::PanicStrategy;
79 use std::{iter, ops};
80 
/// Unit struct naming the pass that turns generators into state machines (see the module
/// documentation above).
// NOTE(review): the `MirPass` implementation for this type is presumably further down the
// file; it is not visible from this section.
pub struct StateTransform;
82 
83 struct RenameLocalVisitor<'tcx> {
84     from: Local,
85     to: Local,
86     tcx: TyCtxt<'tcx>,
87 }
88 
89 impl<'tcx> MutVisitor<'tcx> for RenameLocalVisitor<'tcx> {
tcx(&self) -> TyCtxt<'tcx>90     fn tcx(&self) -> TyCtxt<'tcx> {
91         self.tcx
92     }
93 
visit_local(&mut self, local: &mut Local, _: PlaceContext, _: Location)94     fn visit_local(&mut self, local: &mut Local, _: PlaceContext, _: Location) {
95         if *local == self.from {
96             *local = self.to;
97         }
98     }
99 
visit_terminator(&mut self, terminator: &mut Terminator<'tcx>, location: Location)100     fn visit_terminator(&mut self, terminator: &mut Terminator<'tcx>, location: Location) {
101         match terminator.kind {
102             TerminatorKind::Return => {
103                 // Do not replace the implicit `_0` access here, as that's not possible. The
104                 // transform already handles `return` correctly.
105             }
106             _ => self.super_terminator(terminator, location),
107         }
108     }
109 }
110 
111 struct DerefArgVisitor<'tcx> {
112     tcx: TyCtxt<'tcx>,
113 }
114 
115 impl<'tcx> MutVisitor<'tcx> for DerefArgVisitor<'tcx> {
tcx(&self) -> TyCtxt<'tcx>116     fn tcx(&self) -> TyCtxt<'tcx> {
117         self.tcx
118     }
119 
visit_local(&mut self, local: &mut Local, _: PlaceContext, _: Location)120     fn visit_local(&mut self, local: &mut Local, _: PlaceContext, _: Location) {
121         assert_ne!(*local, SELF_ARG);
122     }
123 
visit_place(&mut self, place: &mut Place<'tcx>, context: PlaceContext, location: Location)124     fn visit_place(&mut self, place: &mut Place<'tcx>, context: PlaceContext, location: Location) {
125         if place.local == SELF_ARG {
126             replace_base(
127                 place,
128                 Place {
129                     local: SELF_ARG,
130                     projection: self.tcx().mk_place_elems(&[ProjectionElem::Deref]),
131                 },
132                 self.tcx,
133             );
134         } else {
135             self.visit_local(&mut place.local, context, location);
136 
137             for elem in place.projection.iter() {
138                 if let PlaceElem::Index(local) = elem {
139                     assert_ne!(local, SELF_ARG);
140                 }
141             }
142         }
143     }
144 }
145 
146 struct PinArgVisitor<'tcx> {
147     ref_gen_ty: Ty<'tcx>,
148     tcx: TyCtxt<'tcx>,
149 }
150 
151 impl<'tcx> MutVisitor<'tcx> for PinArgVisitor<'tcx> {
tcx(&self) -> TyCtxt<'tcx>152     fn tcx(&self) -> TyCtxt<'tcx> {
153         self.tcx
154     }
155 
visit_local(&mut self, local: &mut Local, _: PlaceContext, _: Location)156     fn visit_local(&mut self, local: &mut Local, _: PlaceContext, _: Location) {
157         assert_ne!(*local, SELF_ARG);
158     }
159 
visit_place(&mut self, place: &mut Place<'tcx>, context: PlaceContext, location: Location)160     fn visit_place(&mut self, place: &mut Place<'tcx>, context: PlaceContext, location: Location) {
161         if place.local == SELF_ARG {
162             replace_base(
163                 place,
164                 Place {
165                     local: SELF_ARG,
166                     projection: self.tcx().mk_place_elems(&[ProjectionElem::Field(
167                         FieldIdx::new(0),
168                         self.ref_gen_ty,
169                     )]),
170                 },
171                 self.tcx,
172             );
173         } else {
174             self.visit_local(&mut place.local, context, location);
175 
176             for elem in place.projection.iter() {
177                 if let PlaceElem::Index(local) = elem {
178                     assert_ne!(local, SELF_ARG);
179                 }
180             }
181         }
182     }
183 }
184 
replace_base<'tcx>(place: &mut Place<'tcx>, new_base: Place<'tcx>, tcx: TyCtxt<'tcx>)185 fn replace_base<'tcx>(place: &mut Place<'tcx>, new_base: Place<'tcx>, tcx: TyCtxt<'tcx>) {
186     place.local = new_base.local;
187 
188     let mut new_projection = new_base.projection.to_vec();
189     new_projection.append(&mut place.projection.to_vec());
190 
191     place.projection = tcx.mk_place_elems(&new_projection);
192 }
193 
194 const SELF_ARG: Local = Local::from_u32(1);
195 
196 /// Generator has not been resumed yet.
197 const UNRESUMED: usize = GeneratorSubsts::UNRESUMED;
198 /// Generator has returned / is completed.
199 const RETURNED: usize = GeneratorSubsts::RETURNED;
200 /// Generator has panicked and is poisoned.
201 const POISONED: usize = GeneratorSubsts::POISONED;
202 
203 /// Number of variants to reserve in generator state. Corresponds to
204 /// `UNRESUMED` (beginning of a generator) and `RETURNED`/`POISONED`
205 /// (end of a generator) states.
206 const RESERVED_VARIANTS: usize = 3;
207 
208 /// A `yield` point in the generator.
209 struct SuspensionPoint<'tcx> {
210     /// State discriminant used when suspending or resuming at this point.
211     state: usize,
212     /// The block to jump to after resumption.
213     resume: BasicBlock,
214     /// Where to move the resume argument after resumption.
215     resume_arg: Place<'tcx>,
216     /// Which block to jump to if the generator is dropped in this state.
217     drop: Option<BasicBlock>,
218     /// Set of locals that have live storage while at this suspension point.
219     storage_liveness: GrowableBitSet<Local>,
220 }
221 
222 struct TransformVisitor<'tcx> {
223     tcx: TyCtxt<'tcx>,
224     is_async_kind: bool,
225     state_adt_ref: AdtDef<'tcx>,
226     state_substs: SubstsRef<'tcx>,
227 
228     // The type of the discriminant in the generator struct
229     discr_ty: Ty<'tcx>,
230 
231     // Mapping from Local to (type of local, generator struct index)
232     // FIXME(eddyb) This should use `IndexVec<Local, Option<_>>`.
233     remap: FxHashMap<Local, (Ty<'tcx>, VariantIdx, FieldIdx)>,
234 
235     // A map from a suspension point in a block to the locals which have live storage at that point
236     storage_liveness: IndexVec<BasicBlock, Option<BitSet<Local>>>,
237 
238     // A list of suspension points, generated during the transform
239     suspension_points: Vec<SuspensionPoint<'tcx>>,
240 
241     // The set of locals that have no `StorageLive`/`StorageDead` annotations.
242     always_live_locals: BitSet<Local>,
243 
244     // The original RETURN_PLACE local
245     new_ret_local: Local,
246 }
247 
248 impl<'tcx> TransformVisitor<'tcx> {
249     // Make a `GeneratorState` or `Poll` variant assignment.
250     //
251     // `core::ops::GeneratorState` only has single element tuple variants,
252     // so we can just write to the downcasted first field and then set the
253     // discriminant to the appropriate variant.
make_state( &self, val: Operand<'tcx>, source_info: SourceInfo, is_return: bool, statements: &mut Vec<Statement<'tcx>>, )254     fn make_state(
255         &self,
256         val: Operand<'tcx>,
257         source_info: SourceInfo,
258         is_return: bool,
259         statements: &mut Vec<Statement<'tcx>>,
260     ) {
261         let idx = VariantIdx::new(match (is_return, self.is_async_kind) {
262             (true, false) => 1,  // GeneratorState::Complete
263             (false, false) => 0, // GeneratorState::Yielded
264             (true, true) => 0,   // Poll::Ready
265             (false, true) => 1,  // Poll::Pending
266         });
267 
268         let kind = AggregateKind::Adt(self.state_adt_ref.did(), idx, self.state_substs, None, None);
269 
270         // `Poll::Pending`
271         if self.is_async_kind && idx == VariantIdx::new(1) {
272             assert_eq!(self.state_adt_ref.variant(idx).fields.len(), 0);
273 
274             // FIXME(swatinem): assert that `val` is indeed unit?
275             statements.push(Statement {
276                 kind: StatementKind::Assign(Box::new((
277                     Place::return_place(),
278                     Rvalue::Aggregate(Box::new(kind), IndexVec::new()),
279                 ))),
280                 source_info,
281             });
282             return;
283         }
284 
285         // else: `Poll::Ready(x)`, `GeneratorState::Yielded(x)` or `GeneratorState::Complete(x)`
286         assert_eq!(self.state_adt_ref.variant(idx).fields.len(), 1);
287 
288         statements.push(Statement {
289             kind: StatementKind::Assign(Box::new((
290                 Place::return_place(),
291                 Rvalue::Aggregate(Box::new(kind), [val].into()),
292             ))),
293             source_info,
294         });
295     }
296 
297     // Create a Place referencing a generator struct field
make_field(&self, variant_index: VariantIdx, idx: FieldIdx, ty: Ty<'tcx>) -> Place<'tcx>298     fn make_field(&self, variant_index: VariantIdx, idx: FieldIdx, ty: Ty<'tcx>) -> Place<'tcx> {
299         let self_place = Place::from(SELF_ARG);
300         let base = self.tcx.mk_place_downcast_unnamed(self_place, variant_index);
301         let mut projection = base.projection.to_vec();
302         projection.push(ProjectionElem::Field(idx, ty));
303 
304         Place { local: base.local, projection: self.tcx.mk_place_elems(&projection) }
305     }
306 
307     // Create a statement which changes the discriminant
set_discr(&self, state_disc: VariantIdx, source_info: SourceInfo) -> Statement<'tcx>308     fn set_discr(&self, state_disc: VariantIdx, source_info: SourceInfo) -> Statement<'tcx> {
309         let self_place = Place::from(SELF_ARG);
310         Statement {
311             source_info,
312             kind: StatementKind::SetDiscriminant {
313                 place: Box::new(self_place),
314                 variant_index: state_disc,
315             },
316         }
317     }
318 
319     // Create a statement which reads the discriminant into a temporary
get_discr(&self, body: &mut Body<'tcx>) -> (Statement<'tcx>, Place<'tcx>)320     fn get_discr(&self, body: &mut Body<'tcx>) -> (Statement<'tcx>, Place<'tcx>) {
321         let temp_decl = LocalDecl::new(self.discr_ty, body.span).internal();
322         let local_decls_len = body.local_decls.push(temp_decl);
323         let temp = Place::from(local_decls_len);
324 
325         let self_place = Place::from(SELF_ARG);
326         let assign = Statement {
327             source_info: SourceInfo::outermost(body.span),
328             kind: StatementKind::Assign(Box::new((temp, Rvalue::Discriminant(self_place)))),
329         };
330         (assign, temp)
331     }
332 }
333 
334 impl<'tcx> MutVisitor<'tcx> for TransformVisitor<'tcx> {
tcx(&self) -> TyCtxt<'tcx>335     fn tcx(&self) -> TyCtxt<'tcx> {
336         self.tcx
337     }
338 
visit_local(&mut self, local: &mut Local, _: PlaceContext, _: Location)339     fn visit_local(&mut self, local: &mut Local, _: PlaceContext, _: Location) {
340         assert_eq!(self.remap.get(local), None);
341     }
342 
visit_place( &mut self, place: &mut Place<'tcx>, _context: PlaceContext, _location: Location, )343     fn visit_place(
344         &mut self,
345         place: &mut Place<'tcx>,
346         _context: PlaceContext,
347         _location: Location,
348     ) {
349         // Replace an Local in the remap with a generator struct access
350         if let Some(&(ty, variant_index, idx)) = self.remap.get(&place.local) {
351             replace_base(place, self.make_field(variant_index, idx, ty), self.tcx);
352         }
353     }
354 
visit_basic_block_data(&mut self, block: BasicBlock, data: &mut BasicBlockData<'tcx>)355     fn visit_basic_block_data(&mut self, block: BasicBlock, data: &mut BasicBlockData<'tcx>) {
356         // Remove StorageLive and StorageDead statements for remapped locals
357         data.retain_statements(|s| match s.kind {
358             StatementKind::StorageLive(l) | StatementKind::StorageDead(l) => {
359                 !self.remap.contains_key(&l)
360             }
361             _ => true,
362         });
363 
364         let ret_val = match data.terminator().kind {
365             TerminatorKind::Return => {
366                 Some((true, None, Operand::Move(Place::from(self.new_ret_local)), None))
367             }
368             TerminatorKind::Yield { ref value, resume, resume_arg, drop } => {
369                 Some((false, Some((resume, resume_arg)), value.clone(), drop))
370             }
371             _ => None,
372         };
373 
374         if let Some((is_return, resume, v, drop)) = ret_val {
375             let source_info = data.terminator().source_info;
376             // We must assign the value first in case it gets declared dead below
377             self.make_state(v, source_info, is_return, &mut data.statements);
378             let state = if let Some((resume, mut resume_arg)) = resume {
379                 // Yield
380                 let state = RESERVED_VARIANTS + self.suspension_points.len();
381 
382                 // The resume arg target location might itself be remapped if its base local is
383                 // live across a yield.
384                 let resume_arg =
385                     if let Some(&(ty, variant, idx)) = self.remap.get(&resume_arg.local) {
386                         replace_base(&mut resume_arg, self.make_field(variant, idx, ty), self.tcx);
387                         resume_arg
388                     } else {
389                         resume_arg
390                     };
391 
392                 self.suspension_points.push(SuspensionPoint {
393                     state,
394                     resume,
395                     resume_arg,
396                     drop,
397                     storage_liveness: self.storage_liveness[block].clone().unwrap().into(),
398                 });
399 
400                 VariantIdx::new(state)
401             } else {
402                 // Return
403                 VariantIdx::new(RETURNED) // state for returned
404             };
405             data.statements.push(self.set_discr(state, source_info));
406             data.terminator_mut().kind = TerminatorKind::Return;
407         }
408 
409         self.super_basic_block_data(block, data);
410     }
411 }
412 
make_generator_state_argument_indirect<'tcx>(tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>)413 fn make_generator_state_argument_indirect<'tcx>(tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) {
414     let gen_ty = body.local_decls.raw[1].ty;
415 
416     let ref_gen_ty = Ty::new_ref(
417         tcx,
418         tcx.lifetimes.re_erased,
419         ty::TypeAndMut { ty: gen_ty, mutbl: Mutability::Mut },
420     );
421 
422     // Replace the by value generator argument
423     body.local_decls.raw[1].ty = ref_gen_ty;
424 
425     // Add a deref to accesses of the generator state
426     DerefArgVisitor { tcx }.visit_body(body);
427 }
428 
make_generator_state_argument_pinned<'tcx>(tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>)429 fn make_generator_state_argument_pinned<'tcx>(tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) {
430     let ref_gen_ty = body.local_decls.raw[1].ty;
431 
432     let pin_did = tcx.require_lang_item(LangItem::Pin, Some(body.span));
433     let pin_adt_ref = tcx.adt_def(pin_did);
434     let substs = tcx.mk_substs(&[ref_gen_ty.into()]);
435     let pin_ref_gen_ty = Ty::new_adt(tcx, pin_adt_ref, substs);
436 
437     // Replace the by ref generator argument
438     body.local_decls.raw[1].ty = pin_ref_gen_ty;
439 
440     // Add the Pin field access to accesses of the generator state
441     PinArgVisitor { ref_gen_ty, tcx }.visit_body(body);
442 }
443 
444 /// Allocates a new local and replaces all references of `local` with it. Returns the new local.
445 ///
446 /// `local` will be changed to a new local decl with type `ty`.
447 ///
448 /// Note that the new local will be uninitialized. It is the caller's responsibility to assign some
449 /// valid value to it before its first use.
replace_local<'tcx>( local: Local, ty: Ty<'tcx>, body: &mut Body<'tcx>, tcx: TyCtxt<'tcx>, ) -> Local450 fn replace_local<'tcx>(
451     local: Local,
452     ty: Ty<'tcx>,
453     body: &mut Body<'tcx>,
454     tcx: TyCtxt<'tcx>,
455 ) -> Local {
456     let new_decl = LocalDecl::new(ty, body.span);
457     let new_local = body.local_decls.push(new_decl);
458     body.local_decls.swap(local, new_local);
459 
460     RenameLocalVisitor { from: local, to: new_local, tcx }.visit_body(body);
461 
462     new_local
463 }
464 
465 /// Transforms the `body` of the generator applying the following transforms:
466 ///
467 /// - Eliminates all the `get_context` calls that async lowering created.
468 /// - Replace all `Local` `ResumeTy` types with `&mut Context<'_>` (`context_mut_ref`).
469 ///
470 /// The `Local`s that have their types replaced are:
471 /// - The `resume` argument itself.
472 /// - The argument to `get_context`.
473 /// - The yielded value of a `yield`.
474 ///
475 /// The `ResumeTy` hides a `&mut Context<'_>` behind an unsafe raw pointer, and the
476 /// `get_context` function is being used to convert that back to a `&mut Context<'_>`.
477 ///
478 /// Ideally the async lowering would not use the `ResumeTy`/`get_context` indirection,
479 /// but rather directly use `&mut Context<'_>`, however that would currently
480 /// lead to higher-kinded lifetime errors.
481 /// See <https://github.com/rust-lang/rust/issues/105501>.
482 ///
483 /// The async lowering step and the type / lifetime inference / checking are
484 /// still using the `ResumeTy` indirection for the time being, and that indirection
485 /// is removed here. After this transform, the generator body only knows about `&mut Context<'_>`.
transform_async_context<'tcx>(tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>)486 fn transform_async_context<'tcx>(tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) {
487     let context_mut_ref = Ty::new_task_context(tcx);
488 
489     // replace the type of the `resume` argument
490     replace_resume_ty_local(tcx, body, Local::new(2), context_mut_ref);
491 
492     let get_context_def_id = tcx.require_lang_item(LangItem::GetContext, None);
493 
494     for bb in START_BLOCK..body.basic_blocks.next_index() {
495         let bb_data = &body[bb];
496         if bb_data.is_cleanup {
497             continue;
498         }
499 
500         match &bb_data.terminator().kind {
501             TerminatorKind::Call { func, .. } => {
502                 let func_ty = func.ty(body, tcx);
503                 if let ty::FnDef(def_id, _) = *func_ty.kind() {
504                     if def_id == get_context_def_id {
505                         let local = eliminate_get_context_call(&mut body[bb]);
506                         replace_resume_ty_local(tcx, body, local, context_mut_ref);
507                     }
508                 } else {
509                     continue;
510                 }
511             }
512             TerminatorKind::Yield { resume_arg, .. } => {
513                 replace_resume_ty_local(tcx, body, resume_arg.local, context_mut_ref);
514             }
515             _ => {}
516         }
517     }
518 }
519 
eliminate_get_context_call<'tcx>(bb_data: &mut BasicBlockData<'tcx>) -> Local520 fn eliminate_get_context_call<'tcx>(bb_data: &mut BasicBlockData<'tcx>) -> Local {
521     let terminator = bb_data.terminator.take().unwrap();
522     if let TerminatorKind::Call { mut args, destination, target, .. } = terminator.kind {
523         let arg = args.pop().unwrap();
524         let local = arg.place().unwrap().local;
525 
526         let arg = Rvalue::Use(arg);
527         let assign = Statement {
528             source_info: terminator.source_info,
529             kind: StatementKind::Assign(Box::new((destination, arg))),
530         };
531         bb_data.statements.push(assign);
532         bb_data.terminator = Some(Terminator {
533             source_info: terminator.source_info,
534             kind: TerminatorKind::Goto { target: target.unwrap() },
535         });
536         local
537     } else {
538         bug!();
539     }
540 }
541 
542 #[cfg_attr(not(debug_assertions), allow(unused))]
replace_resume_ty_local<'tcx>( tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>, local: Local, context_mut_ref: Ty<'tcx>, )543 fn replace_resume_ty_local<'tcx>(
544     tcx: TyCtxt<'tcx>,
545     body: &mut Body<'tcx>,
546     local: Local,
547     context_mut_ref: Ty<'tcx>,
548 ) {
549     let local_ty = std::mem::replace(&mut body.local_decls[local].ty, context_mut_ref);
550     // We have to replace the `ResumeTy` that is used for type and borrow checking
551     // with `&mut Context<'_>` in MIR.
552     #[cfg(debug_assertions)]
553     {
554         if let ty::Adt(resume_ty_adt, _) = local_ty.kind() {
555             let expected_adt = tcx.adt_def(tcx.require_lang_item(LangItem::ResumeTy, None));
556             assert_eq!(*resume_ty_adt, expected_adt);
557         } else {
558             panic!("expected `ResumeTy`, found `{:?}`", local_ty);
559         };
560     }
561 }
562 
563 struct LivenessInfo {
564     /// Which locals are live across any suspension point.
565     saved_locals: GeneratorSavedLocals,
566 
567     /// The set of saved locals live at each suspension point.
568     live_locals_at_suspension_points: Vec<BitSet<GeneratorSavedLocal>>,
569 
570     /// Parallel vec to the above with SourceInfo for each yield terminator.
571     source_info_at_suspension_points: Vec<SourceInfo>,
572 
573     /// For every saved local, the set of other saved locals that are
574     /// storage-live at the same time as this local. We cannot overlap locals in
575     /// the layout which have conflicting storage.
576     storage_conflicts: BitMatrix<GeneratorSavedLocal, GeneratorSavedLocal>,
577 
578     /// For every suspending block, the locals which are storage-live across
579     /// that suspension point.
580     storage_liveness: IndexVec<BasicBlock, Option<BitSet<Local>>>,
581 }
582 
locals_live_across_suspend_points<'tcx>( tcx: TyCtxt<'tcx>, body: &Body<'tcx>, always_live_locals: &BitSet<Local>, movable: bool, ) -> LivenessInfo583 fn locals_live_across_suspend_points<'tcx>(
584     tcx: TyCtxt<'tcx>,
585     body: &Body<'tcx>,
586     always_live_locals: &BitSet<Local>,
587     movable: bool,
588 ) -> LivenessInfo {
589     let body_ref: &Body<'_> = &body;
590 
591     // Calculate when MIR locals have live storage. This gives us an upper bound of their
592     // lifetimes.
593     let mut storage_live = MaybeStorageLive::new(std::borrow::Cow::Borrowed(always_live_locals))
594         .into_engine(tcx, body_ref)
595         .iterate_to_fixpoint()
596         .into_results_cursor(body_ref);
597 
598     // Calculate the MIR locals which have been previously
599     // borrowed (even if they are still active).
600     let borrowed_locals_results =
601         MaybeBorrowedLocals.into_engine(tcx, body_ref).pass_name("generator").iterate_to_fixpoint();
602 
603     let mut borrowed_locals_cursor = borrowed_locals_results.cloned_results_cursor(body_ref);
604 
605     // Calculate the MIR locals that we actually need to keep storage around
606     // for.
607     let mut requires_storage_results =
608         MaybeRequiresStorage::new(borrowed_locals_results.cloned_results_cursor(body))
609             .into_engine(tcx, body_ref)
610             .iterate_to_fixpoint();
611     let mut requires_storage_cursor = requires_storage_results.as_results_cursor(body_ref);
612 
613     // Calculate the liveness of MIR locals ignoring borrows.
614     let mut liveness = MaybeLiveLocals
615         .into_engine(tcx, body_ref)
616         .pass_name("generator")
617         .iterate_to_fixpoint()
618         .into_results_cursor(body_ref);
619 
620     let mut storage_liveness_map = IndexVec::from_elem(None, &body.basic_blocks);
621     let mut live_locals_at_suspension_points = Vec::new();
622     let mut source_info_at_suspension_points = Vec::new();
623     let mut live_locals_at_any_suspension_point = BitSet::new_empty(body.local_decls.len());
624 
625     for (block, data) in body.basic_blocks.iter_enumerated() {
626         if let TerminatorKind::Yield { .. } = data.terminator().kind {
627             let loc = Location { block, statement_index: data.statements.len() };
628 
629             liveness.seek_to_block_end(block);
630             let mut live_locals: BitSet<_> = BitSet::new_empty(body.local_decls.len());
631             live_locals.union(liveness.get());
632 
633             if !movable {
634                 // The `liveness` variable contains the liveness of MIR locals ignoring borrows.
635                 // This is correct for movable generators since borrows cannot live across
636                 // suspension points. However for immovable generators we need to account for
637                 // borrows, so we conservatively assume that all borrowed locals are live until
638                 // we find a StorageDead statement referencing the locals.
639                 // To do this we just union our `liveness` result with `borrowed_locals`, which
640                 // contains all the locals which has been borrowed before this suspension point.
641                 // If a borrow is converted to a raw reference, we must also assume that it lives
642                 // forever. Note that the final liveness is still bounded by the storage liveness
643                 // of the local, which happens using the `intersect` operation below.
644                 borrowed_locals_cursor.seek_before_primary_effect(loc);
645                 live_locals.union(borrowed_locals_cursor.get());
646             }
647 
648             // Store the storage liveness for later use so we can restore the state
649             // after a suspension point
650             storage_live.seek_before_primary_effect(loc);
651             storage_liveness_map[block] = Some(storage_live.get().clone());
652 
653             // Locals live are live at this point only if they are used across
654             // suspension points (the `liveness` variable)
655             // and their storage is required (the `storage_required` variable)
656             requires_storage_cursor.seek_before_primary_effect(loc);
657             live_locals.intersect(requires_storage_cursor.get());
658 
659             // The generator argument is ignored.
660             live_locals.remove(SELF_ARG);
661 
662             debug!("loc = {:?}, live_locals = {:?}", loc, live_locals);
663 
664             // Add the locals live at this suspension point to the set of locals which live across
665             // any suspension points
666             live_locals_at_any_suspension_point.union(&live_locals);
667 
668             live_locals_at_suspension_points.push(live_locals);
669             source_info_at_suspension_points.push(data.terminator().source_info);
670         }
671     }
672 
673     debug!("live_locals_anywhere = {:?}", live_locals_at_any_suspension_point);
674     let saved_locals = GeneratorSavedLocals(live_locals_at_any_suspension_point);
675 
676     // Renumber our liveness_map bitsets to include only the locals we are
677     // saving.
678     let live_locals_at_suspension_points = live_locals_at_suspension_points
679         .iter()
680         .map(|live_here| saved_locals.renumber_bitset(&live_here))
681         .collect();
682 
683     let storage_conflicts = compute_storage_conflicts(
684         body_ref,
685         &saved_locals,
686         always_live_locals.clone(),
687         requires_storage_results,
688     );
689 
690     LivenessInfo {
691         saved_locals,
692         live_locals_at_suspension_points,
693         source_info_at_suspension_points,
694         storage_conflicts,
695         storage_liveness: storage_liveness_map,
696     }
697 }
698 
699 /// The set of `Local`s that must be saved across yield points.
700 ///
701 /// `GeneratorSavedLocal` is indexed in terms of the elements in this set;
702 /// i.e. `GeneratorSavedLocal::new(1)` corresponds to the second local
703 /// included in this set.
704 struct GeneratorSavedLocals(BitSet<Local>);
705 
706 impl GeneratorSavedLocals {
707     /// Returns an iterator over each `GeneratorSavedLocal` along with the `Local` it corresponds
708     /// to.
iter_enumerated(&self) -> impl '_ + Iterator<Item = (GeneratorSavedLocal, Local)>709     fn iter_enumerated(&self) -> impl '_ + Iterator<Item = (GeneratorSavedLocal, Local)> {
710         self.iter().enumerate().map(|(i, l)| (GeneratorSavedLocal::from(i), l))
711     }
712 
713     /// Transforms a `BitSet<Local>` that contains only locals saved across yield points to the
714     /// equivalent `BitSet<GeneratorSavedLocal>`.
renumber_bitset(&self, input: &BitSet<Local>) -> BitSet<GeneratorSavedLocal>715     fn renumber_bitset(&self, input: &BitSet<Local>) -> BitSet<GeneratorSavedLocal> {
716         assert!(self.superset(&input), "{:?} not a superset of {:?}", self.0, input);
717         let mut out = BitSet::new_empty(self.count());
718         for (saved_local, local) in self.iter_enumerated() {
719             if input.contains(local) {
720                 out.insert(saved_local);
721             }
722         }
723         out
724     }
725 
get(&self, local: Local) -> Option<GeneratorSavedLocal>726     fn get(&self, local: Local) -> Option<GeneratorSavedLocal> {
727         if !self.contains(local) {
728             return None;
729         }
730 
731         let idx = self.iter().take_while(|&l| l < local).count();
732         Some(GeneratorSavedLocal::new(idx))
733     }
734 }
735 
736 impl ops::Deref for GeneratorSavedLocals {
737     type Target = BitSet<Local>;
738 
deref(&self) -> &Self::Target739     fn deref(&self) -> &Self::Target {
740         &self.0
741     }
742 }
743 
744 /// For every saved local, looks for which locals are StorageLive at the same
745 /// time. Generates a bitset for every local of all the other locals that may be
746 /// StorageLive simultaneously with that local. This is used in the layout
747 /// computation; see `GeneratorLayout` for more.
compute_storage_conflicts<'mir, 'tcx>( body: &'mir Body<'tcx>, saved_locals: &GeneratorSavedLocals, always_live_locals: BitSet<Local>, mut requires_storage: rustc_mir_dataflow::Results<'tcx, MaybeRequiresStorage<'_, 'mir, 'tcx>>, ) -> BitMatrix<GeneratorSavedLocal, GeneratorSavedLocal>748 fn compute_storage_conflicts<'mir, 'tcx>(
749     body: &'mir Body<'tcx>,
750     saved_locals: &GeneratorSavedLocals,
751     always_live_locals: BitSet<Local>,
752     mut requires_storage: rustc_mir_dataflow::Results<'tcx, MaybeRequiresStorage<'_, 'mir, 'tcx>>,
753 ) -> BitMatrix<GeneratorSavedLocal, GeneratorSavedLocal> {
754     assert_eq!(body.local_decls.len(), saved_locals.domain_size());
755 
756     debug!("compute_storage_conflicts({:?})", body.span);
757     debug!("always_live = {:?}", always_live_locals);
758 
759     // Locals that are always live or ones that need to be stored across
760     // suspension points are not eligible for overlap.
761     let mut ineligible_locals = always_live_locals;
762     ineligible_locals.intersect(&**saved_locals);
763 
764     // Compute the storage conflicts for all eligible locals.
765     let mut visitor = StorageConflictVisitor {
766         body,
767         saved_locals: &saved_locals,
768         local_conflicts: BitMatrix::from_row_n(&ineligible_locals, body.local_decls.len()),
769     };
770 
771     requires_storage.visit_reachable_with(body, &mut visitor);
772 
773     let local_conflicts = visitor.local_conflicts;
774 
775     // Compress the matrix using only stored locals (Local -> GeneratorSavedLocal).
776     //
777     // NOTE: Today we store a full conflict bitset for every local. Technically
778     // this is twice as many bits as we need, since the relation is symmetric.
779     // However, in practice these bitsets are not usually large. The layout code
780     // also needs to keep track of how many conflicts each local has, so it's
781     // simpler to keep it this way for now.
782     let mut storage_conflicts = BitMatrix::new(saved_locals.count(), saved_locals.count());
783     for (saved_local_a, local_a) in saved_locals.iter_enumerated() {
784         if ineligible_locals.contains(local_a) {
785             // Conflicts with everything.
786             storage_conflicts.insert_all_into_row(saved_local_a);
787         } else {
788             // Keep overlap information only for stored locals.
789             for (saved_local_b, local_b) in saved_locals.iter_enumerated() {
790                 if local_conflicts.contains(local_a, local_b) {
791                     storage_conflicts.insert(saved_local_a, saved_local_b);
792                 }
793             }
794         }
795     }
796     storage_conflicts
797 }
798 
799 struct StorageConflictVisitor<'mir, 'tcx, 's> {
800     body: &'mir Body<'tcx>,
801     saved_locals: &'s GeneratorSavedLocals,
802     // FIXME(tmandry): Consider using sparse bitsets here once we have good
803     // benchmarks for generators.
804     local_conflicts: BitMatrix<Local, Local>,
805 }
806 
807 impl<'mir, 'tcx, R> rustc_mir_dataflow::ResultsVisitor<'mir, 'tcx, R>
808     for StorageConflictVisitor<'mir, 'tcx, '_>
809 {
810     type FlowState = BitSet<Local>;
811 
visit_statement_before_primary_effect( &mut self, _results: &R, state: &Self::FlowState, _statement: &'mir Statement<'tcx>, loc: Location, )812     fn visit_statement_before_primary_effect(
813         &mut self,
814         _results: &R,
815         state: &Self::FlowState,
816         _statement: &'mir Statement<'tcx>,
817         loc: Location,
818     ) {
819         self.apply_state(state, loc);
820     }
821 
visit_terminator_before_primary_effect( &mut self, _results: &R, state: &Self::FlowState, _terminator: &'mir Terminator<'tcx>, loc: Location, )822     fn visit_terminator_before_primary_effect(
823         &mut self,
824         _results: &R,
825         state: &Self::FlowState,
826         _terminator: &'mir Terminator<'tcx>,
827         loc: Location,
828     ) {
829         self.apply_state(state, loc);
830     }
831 }
832 
833 impl StorageConflictVisitor<'_, '_, '_> {
apply_state(&mut self, flow_state: &BitSet<Local>, loc: Location)834     fn apply_state(&mut self, flow_state: &BitSet<Local>, loc: Location) {
835         // Ignore unreachable blocks.
836         if self.body.basic_blocks[loc.block].terminator().kind == TerminatorKind::Unreachable {
837             return;
838         }
839 
840         let mut eligible_storage_live = flow_state.clone();
841         eligible_storage_live.intersect(&**self.saved_locals);
842 
843         for local in eligible_storage_live.iter() {
844             self.local_conflicts.union_row_with(&eligible_storage_live, local);
845         }
846 
847         if eligible_storage_live.count() > 1 {
848             trace!("at {:?}, eligible_storage_live={:?}", loc, eligible_storage_live);
849         }
850     }
851 }
852 
853 /// Validates the typeck view of the generator against the actual set of types saved between
854 /// yield points.
sanitize_witness<'tcx>( tcx: TyCtxt<'tcx>, body: &Body<'tcx>, witness: Ty<'tcx>, upvars: Vec<Ty<'tcx>>, layout: &GeneratorLayout<'tcx>, )855 fn sanitize_witness<'tcx>(
856     tcx: TyCtxt<'tcx>,
857     body: &Body<'tcx>,
858     witness: Ty<'tcx>,
859     upvars: Vec<Ty<'tcx>>,
860     layout: &GeneratorLayout<'tcx>,
861 ) {
862     let did = body.source.def_id();
863     let param_env = tcx.param_env(did);
864 
865     let allowed_upvars = tcx.normalize_erasing_regions(param_env, upvars);
866     let allowed = match witness.kind() {
867         &ty::GeneratorWitness(interior_tys) => {
868             tcx.normalize_erasing_late_bound_regions(param_env, interior_tys)
869         }
870         _ => {
871             tcx.sess.delay_span_bug(
872                 body.span,
873                 format!("unexpected generator witness type {:?}", witness.kind()),
874             );
875             return;
876         }
877     };
878 
879     let mut mismatches = Vec::new();
880     for fty in &layout.field_tys {
881         if fty.ignore_for_traits {
882             continue;
883         }
884         let decl_ty = tcx.normalize_erasing_regions(param_env, fty.ty);
885 
886         // Sanity check that typeck knows about the type of locals which are
887         // live across a suspension point
888         if !allowed.contains(&decl_ty) && !allowed_upvars.contains(&decl_ty) {
889             mismatches.push(decl_ty);
890         }
891     }
892 
893     if !mismatches.is_empty() {
894         span_bug!(
895             body.span,
896             "Broken MIR: generator contains type {:?} in MIR, \
897                        but typeck only knows about {} and {:?}",
898             mismatches,
899             allowed,
900             allowed_upvars
901         );
902     }
903 }
904 
compute_layout<'tcx>( tcx: TyCtxt<'tcx>, liveness: LivenessInfo, body: &Body<'tcx>, ) -> ( FxHashMap<Local, (Ty<'tcx>, VariantIdx, FieldIdx)>, GeneratorLayout<'tcx>, IndexVec<BasicBlock, Option<BitSet<Local>>>, )905 fn compute_layout<'tcx>(
906     tcx: TyCtxt<'tcx>,
907     liveness: LivenessInfo,
908     body: &Body<'tcx>,
909 ) -> (
910     FxHashMap<Local, (Ty<'tcx>, VariantIdx, FieldIdx)>,
911     GeneratorLayout<'tcx>,
912     IndexVec<BasicBlock, Option<BitSet<Local>>>,
913 ) {
914     let LivenessInfo {
915         saved_locals,
916         live_locals_at_suspension_points,
917         source_info_at_suspension_points,
918         storage_conflicts,
919         storage_liveness,
920     } = liveness;
921 
922     // Gather live local types and their indices.
923     let mut locals = IndexVec::<GeneratorSavedLocal, _>::new();
924     let mut tys = IndexVec::<GeneratorSavedLocal, _>::new();
925     for (saved_local, local) in saved_locals.iter_enumerated() {
926         debug!("generator saved local {:?} => {:?}", saved_local, local);
927 
928         locals.push(local);
929         let decl = &body.local_decls[local];
930         debug!(?decl);
931 
932         let ignore_for_traits = if tcx.sess.opts.unstable_opts.drop_tracking_mir {
933             // Do not `assert_crate_local` here, as post-borrowck cleanup may have already cleared
934             // the information. This is alright, since `ignore_for_traits` is only relevant when
935             // this code runs on pre-cleanup MIR, and `ignore_for_traits = false` is the safer
936             // default.
937             match decl.local_info {
938                 // Do not include raw pointers created from accessing `static` items, as those could
939                 // well be re-created by another access to the same static.
940                 ClearCrossCrate::Set(box LocalInfo::StaticRef { is_thread_local, .. }) => {
941                     !is_thread_local
942                 }
943                 // Fake borrows are only read by fake reads, so do not have any reality in
944                 // post-analysis MIR.
945                 ClearCrossCrate::Set(box LocalInfo::FakeBorrow) => true,
946                 _ => false,
947             }
948         } else {
949             // FIXME(#105084) HIR-based drop tracking does not account for all the temporaries that
950             // MIR building may introduce. This leads to wrongly ignored types, but this is
951             // necessary for internal consistency and to avoid ICEs.
952             decl.internal
953         };
954         let decl =
955             GeneratorSavedTy { ty: decl.ty, source_info: decl.source_info, ignore_for_traits };
956         debug!(?decl);
957 
958         tys.push(decl);
959     }
960 
961     // Leave empty variants for the UNRESUMED, RETURNED, and POISONED states.
962     // In debuginfo, these will correspond to the beginning (UNRESUMED) or end
963     // (RETURNED, POISONED) of the function.
964     let body_span = body.source_scopes[OUTERMOST_SOURCE_SCOPE].span;
965     let mut variant_source_info: IndexVec<VariantIdx, SourceInfo> = [
966         SourceInfo::outermost(body_span.shrink_to_lo()),
967         SourceInfo::outermost(body_span.shrink_to_hi()),
968         SourceInfo::outermost(body_span.shrink_to_hi()),
969     ]
970     .iter()
971     .copied()
972     .collect();
973 
974     // Build the generator variant field list.
975     // Create a map from local indices to generator struct indices.
976     let mut variant_fields: IndexVec<VariantIdx, IndexVec<FieldIdx, GeneratorSavedLocal>> =
977         iter::repeat(IndexVec::new()).take(RESERVED_VARIANTS).collect();
978     let mut remap = FxHashMap::default();
979     for (suspension_point_idx, live_locals) in live_locals_at_suspension_points.iter().enumerate() {
980         let variant_index = VariantIdx::from(RESERVED_VARIANTS + suspension_point_idx);
981         let mut fields = IndexVec::new();
982         for (idx, saved_local) in live_locals.iter().enumerate() {
983             fields.push(saved_local);
984             // Note that if a field is included in multiple variants, we will
985             // just use the first one here. That's fine; fields do not move
986             // around inside generators, so it doesn't matter which variant
987             // index we access them by.
988             let idx = FieldIdx::from_usize(idx);
989             remap.entry(locals[saved_local]).or_insert((tys[saved_local].ty, variant_index, idx));
990         }
991         variant_fields.push(fields);
992         variant_source_info.push(source_info_at_suspension_points[suspension_point_idx]);
993     }
994     debug!("generator variant_fields = {:?}", variant_fields);
995     debug!("generator storage_conflicts = {:#?}", storage_conflicts);
996 
997     let mut field_names = IndexVec::from_elem(None, &tys);
998     for var in &body.var_debug_info {
999         let VarDebugInfoContents::Place(place) = &var.value else { continue };
1000         let Some(local) = place.as_local() else { continue };
1001         let Some(&(_, variant, field)) = remap.get(&local) else { continue };
1002 
1003         let saved_local = variant_fields[variant][field];
1004         field_names.get_or_insert_with(saved_local, || var.name);
1005     }
1006 
1007     let layout = GeneratorLayout {
1008         field_tys: tys,
1009         field_names,
1010         variant_fields,
1011         variant_source_info,
1012         storage_conflicts,
1013     };
1014     debug!(?layout);
1015 
1016     (remap, layout, storage_liveness)
1017 }
1018 
1019 /// Replaces the entry point of `body` with a block that switches on the generator discriminant and
1020 /// dispatches to blocks according to `cases`.
1021 ///
1022 /// After this function, the former entry point of the function will be bb1.
insert_switch<'tcx>( body: &mut Body<'tcx>, cases: Vec<(usize, BasicBlock)>, transform: &TransformVisitor<'tcx>, default: TerminatorKind<'tcx>, )1023 fn insert_switch<'tcx>(
1024     body: &mut Body<'tcx>,
1025     cases: Vec<(usize, BasicBlock)>,
1026     transform: &TransformVisitor<'tcx>,
1027     default: TerminatorKind<'tcx>,
1028 ) {
1029     let default_block = insert_term_block(body, default);
1030     let (assign, discr) = transform.get_discr(body);
1031     let switch_targets =
1032         SwitchTargets::new(cases.iter().map(|(i, bb)| ((*i) as u128, *bb)), default_block);
1033     let switch = TerminatorKind::SwitchInt { discr: Operand::Move(discr), targets: switch_targets };
1034 
1035     let source_info = SourceInfo::outermost(body.span);
1036     body.basic_blocks_mut().raw.insert(
1037         0,
1038         BasicBlockData {
1039             statements: vec![assign],
1040             terminator: Some(Terminator { source_info, kind: switch }),
1041             is_cleanup: false,
1042         },
1043     );
1044 
1045     let blocks = body.basic_blocks_mut().iter_mut();
1046 
1047     for target in blocks.flat_map(|b| b.terminator_mut().successors_mut()) {
1048         *target = BasicBlock::new(target.index() + 1);
1049     }
1050 }
1051 
elaborate_generator_drops<'tcx>(tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>)1052 fn elaborate_generator_drops<'tcx>(tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) {
1053     use crate::shim::DropShimElaborator;
1054     use rustc_middle::mir::patch::MirPatch;
1055     use rustc_mir_dataflow::elaborate_drops::{elaborate_drop, Unwind};
1056 
1057     // Note that `elaborate_drops` only drops the upvars of a generator, and
1058     // this is ok because `open_drop` can only be reached within that own
1059     // generator's resume function.
1060 
1061     let def_id = body.source.def_id();
1062     let param_env = tcx.param_env(def_id);
1063 
1064     let mut elaborator = DropShimElaborator { body, patch: MirPatch::new(body), tcx, param_env };
1065 
1066     for (block, block_data) in body.basic_blocks.iter_enumerated() {
1067         let (target, unwind, source_info) = match block_data.terminator() {
1068             Terminator {
1069                 source_info,
1070                 kind: TerminatorKind::Drop { place, target, unwind, replace: _ },
1071             } => {
1072                 if let Some(local) = place.as_local() {
1073                     if local == SELF_ARG {
1074                         (target, unwind, source_info)
1075                     } else {
1076                         continue;
1077                     }
1078                 } else {
1079                     continue;
1080                 }
1081             }
1082             _ => continue,
1083         };
1084         let unwind = if block_data.is_cleanup {
1085             Unwind::InCleanup
1086         } else {
1087             Unwind::To(match *unwind {
1088                 UnwindAction::Cleanup(tgt) => tgt,
1089                 UnwindAction::Continue => elaborator.patch.resume_block(),
1090                 UnwindAction::Unreachable => elaborator.patch.unreachable_cleanup_block(),
1091                 UnwindAction::Terminate => elaborator.patch.terminate_block(),
1092             })
1093         };
1094         elaborate_drop(
1095             &mut elaborator,
1096             *source_info,
1097             Place::from(SELF_ARG),
1098             (),
1099             *target,
1100             unwind,
1101             block,
1102         );
1103     }
1104     elaborator.patch.apply(body);
1105 }
1106 
create_generator_drop_shim<'tcx>( tcx: TyCtxt<'tcx>, transform: &TransformVisitor<'tcx>, gen_ty: Ty<'tcx>, body: &mut Body<'tcx>, drop_clean: BasicBlock, ) -> Body<'tcx>1107 fn create_generator_drop_shim<'tcx>(
1108     tcx: TyCtxt<'tcx>,
1109     transform: &TransformVisitor<'tcx>,
1110     gen_ty: Ty<'tcx>,
1111     body: &mut Body<'tcx>,
1112     drop_clean: BasicBlock,
1113 ) -> Body<'tcx> {
1114     let mut body = body.clone();
1115     body.arg_count = 1; // make sure the resume argument is not included here
1116 
1117     let source_info = SourceInfo::outermost(body.span);
1118 
1119     let mut cases = create_cases(&mut body, transform, Operation::Drop);
1120 
1121     cases.insert(0, (UNRESUMED, drop_clean));
1122 
1123     // The returned state and the poisoned state fall through to the default
1124     // case which is just to return
1125 
1126     insert_switch(&mut body, cases, &transform, TerminatorKind::Return);
1127 
1128     for block in body.basic_blocks_mut() {
1129         let kind = &mut block.terminator_mut().kind;
1130         if let TerminatorKind::GeneratorDrop = *kind {
1131             *kind = TerminatorKind::Return;
1132         }
1133     }
1134 
1135     // Replace the return variable
1136     body.local_decls[RETURN_PLACE] = LocalDecl::with_source_info(Ty::new_unit(tcx), source_info);
1137 
1138     make_generator_state_argument_indirect(tcx, &mut body);
1139 
1140     // Change the generator argument from &mut to *mut
1141     body.local_decls[SELF_ARG] = LocalDecl::with_source_info(
1142         Ty::new_ptr(tcx, ty::TypeAndMut { ty: gen_ty, mutbl: hir::Mutability::Mut }),
1143         source_info,
1144     );
1145 
1146     // Make sure we remove dead blocks to remove
1147     // unrelated code from the resume part of the function
1148     simplify::remove_dead_blocks(tcx, &mut body);
1149 
1150     dump_mir(tcx, false, "generator_drop", &0, &body, |_, _| Ok(()));
1151 
1152     body
1153 }
1154 
insert_term_block<'tcx>(body: &mut Body<'tcx>, kind: TerminatorKind<'tcx>) -> BasicBlock1155 fn insert_term_block<'tcx>(body: &mut Body<'tcx>, kind: TerminatorKind<'tcx>) -> BasicBlock {
1156     let source_info = SourceInfo::outermost(body.span);
1157     body.basic_blocks_mut().push(BasicBlockData {
1158         statements: Vec::new(),
1159         terminator: Some(Terminator { source_info, kind }),
1160         is_cleanup: false,
1161     })
1162 }
1163 
insert_panic_block<'tcx>( tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>, message: AssertMessage<'tcx>, ) -> BasicBlock1164 fn insert_panic_block<'tcx>(
1165     tcx: TyCtxt<'tcx>,
1166     body: &mut Body<'tcx>,
1167     message: AssertMessage<'tcx>,
1168 ) -> BasicBlock {
1169     let assert_block = BasicBlock::new(body.basic_blocks.len());
1170     let term = TerminatorKind::Assert {
1171         cond: Operand::Constant(Box::new(Constant {
1172             span: body.span,
1173             user_ty: None,
1174             literal: ConstantKind::from_bool(tcx, false),
1175         })),
1176         expected: true,
1177         msg: Box::new(message),
1178         target: assert_block,
1179         unwind: UnwindAction::Continue,
1180     };
1181 
1182     let source_info = SourceInfo::outermost(body.span);
1183     body.basic_blocks_mut().push(BasicBlockData {
1184         statements: Vec::new(),
1185         terminator: Some(Terminator { source_info, kind: term }),
1186         is_cleanup: false,
1187     });
1188 
1189     assert_block
1190 }
1191 
can_return<'tcx>(tcx: TyCtxt<'tcx>, body: &Body<'tcx>, param_env: ty::ParamEnv<'tcx>) -> bool1192 fn can_return<'tcx>(tcx: TyCtxt<'tcx>, body: &Body<'tcx>, param_env: ty::ParamEnv<'tcx>) -> bool {
1193     // Returning from a function with an uninhabited return type is undefined behavior.
1194     if body.return_ty().is_privately_uninhabited(tcx, param_env) {
1195         return false;
1196     }
1197 
1198     // If there's a return terminator the function may return.
1199     for block in body.basic_blocks.iter() {
1200         if let TerminatorKind::Return = block.terminator().kind {
1201             return true;
1202         }
1203     }
1204 
1205     // Otherwise the function can't return.
1206     false
1207 }
1208 
can_unwind<'tcx>(tcx: TyCtxt<'tcx>, body: &Body<'tcx>) -> bool1209 fn can_unwind<'tcx>(tcx: TyCtxt<'tcx>, body: &Body<'tcx>) -> bool {
1210     // Nothing can unwind when landing pads are off.
1211     if tcx.sess.panic_strategy() == PanicStrategy::Abort {
1212         return false;
1213     }
1214 
1215     // Unwinds can only start at certain terminators.
1216     for block in body.basic_blocks.iter() {
1217         match block.terminator().kind {
1218             // These never unwind.
1219             TerminatorKind::Goto { .. }
1220             | TerminatorKind::SwitchInt { .. }
1221             | TerminatorKind::Terminate
1222             | TerminatorKind::Return
1223             | TerminatorKind::Unreachable
1224             | TerminatorKind::GeneratorDrop
1225             | TerminatorKind::FalseEdge { .. }
1226             | TerminatorKind::FalseUnwind { .. } => {}
1227 
1228             // Resume will *continue* unwinding, but if there's no other unwinding terminator it
1229             // will never be reached.
1230             TerminatorKind::Resume => {}
1231 
1232             TerminatorKind::Yield { .. } => {
1233                 unreachable!("`can_unwind` called before generator transform")
1234             }
1235 
1236             // These may unwind.
1237             TerminatorKind::Drop { .. }
1238             | TerminatorKind::Call { .. }
1239             | TerminatorKind::InlineAsm { .. }
1240             | TerminatorKind::Assert { .. } => return true,
1241         }
1242     }
1243 
1244     // If we didn't find an unwinding terminator, the function cannot unwind.
1245     false
1246 }
1247 
create_generator_resume_function<'tcx>( tcx: TyCtxt<'tcx>, transform: TransformVisitor<'tcx>, body: &mut Body<'tcx>, can_return: bool, )1248 fn create_generator_resume_function<'tcx>(
1249     tcx: TyCtxt<'tcx>,
1250     transform: TransformVisitor<'tcx>,
1251     body: &mut Body<'tcx>,
1252     can_return: bool,
1253 ) {
1254     let can_unwind = can_unwind(tcx, body);
1255 
1256     // Poison the generator when it unwinds
1257     if can_unwind {
1258         let source_info = SourceInfo::outermost(body.span);
1259         let poison_block = body.basic_blocks_mut().push(BasicBlockData {
1260             statements: vec![transform.set_discr(VariantIdx::new(POISONED), source_info)],
1261             terminator: Some(Terminator { source_info, kind: TerminatorKind::Resume }),
1262             is_cleanup: true,
1263         });
1264 
1265         for (idx, block) in body.basic_blocks_mut().iter_enumerated_mut() {
1266             let source_info = block.terminator().source_info;
1267 
1268             if let TerminatorKind::Resume = block.terminator().kind {
1269                 // An existing `Resume` terminator is redirected to jump to our dedicated
1270                 // "poisoning block" above.
1271                 if idx != poison_block {
1272                     *block.terminator_mut() = Terminator {
1273                         source_info,
1274                         kind: TerminatorKind::Goto { target: poison_block },
1275                     };
1276                 }
1277             } else if !block.is_cleanup {
1278                 // Any terminators that *can* unwind but don't have an unwind target set are also
1279                 // pointed at our poisoning block (unless they're part of the cleanup path).
1280                 if let Some(unwind @ UnwindAction::Continue) = block.terminator_mut().unwind_mut() {
1281                     *unwind = UnwindAction::Cleanup(poison_block);
1282                 }
1283             }
1284         }
1285     }
1286 
1287     let mut cases = create_cases(body, &transform, Operation::Resume);
1288 
1289     use rustc_middle::mir::AssertKind::{ResumedAfterPanic, ResumedAfterReturn};
1290 
1291     // Jump to the entry point on the unresumed
1292     cases.insert(0, (UNRESUMED, START_BLOCK));
1293 
1294     // Panic when resumed on the returned or poisoned state
1295     let generator_kind = body.generator_kind().unwrap();
1296 
1297     if can_unwind {
1298         cases.insert(
1299             1,
1300             (POISONED, insert_panic_block(tcx, body, ResumedAfterPanic(generator_kind))),
1301         );
1302     }
1303 
1304     if can_return {
1305         cases.insert(
1306             1,
1307             (RETURNED, insert_panic_block(tcx, body, ResumedAfterReturn(generator_kind))),
1308         );
1309     }
1310 
1311     insert_switch(body, cases, &transform, TerminatorKind::Unreachable);
1312 
1313     make_generator_state_argument_indirect(tcx, body);
1314     make_generator_state_argument_pinned(tcx, body);
1315 
1316     // Make sure we remove dead blocks to remove
1317     // unrelated code from the drop part of the function
1318     simplify::remove_dead_blocks(tcx, body);
1319 
1320     dump_mir(tcx, false, "generator_resume", &0, body, |_, _| Ok(()));
1321 }
1322 
insert_clean_drop(body: &mut Body<'_>) -> BasicBlock1323 fn insert_clean_drop(body: &mut Body<'_>) -> BasicBlock {
1324     let return_block = insert_term_block(body, TerminatorKind::Return);
1325 
1326     let term = TerminatorKind::Drop {
1327         place: Place::from(SELF_ARG),
1328         target: return_block,
1329         unwind: UnwindAction::Continue,
1330         replace: false,
1331     };
1332     let source_info = SourceInfo::outermost(body.span);
1333 
1334     // Create a block to destroy an unresumed generators. This can only destroy upvars.
1335     body.basic_blocks_mut().push(BasicBlockData {
1336         statements: Vec::new(),
1337         terminator: Some(Terminator { source_info, kind: term }),
1338         is_cleanup: false,
1339     })
1340 }
1341 
1342 /// An operation that can be performed on a generator.
1343 #[derive(PartialEq, Copy, Clone)]
1344 enum Operation {
1345     Resume,
1346     Drop,
1347 }
1348 
1349 impl Operation {
target_block(self, point: &SuspensionPoint<'_>) -> Option<BasicBlock>1350     fn target_block(self, point: &SuspensionPoint<'_>) -> Option<BasicBlock> {
1351         match self {
1352             Operation::Resume => Some(point.resume),
1353             Operation::Drop => point.drop,
1354         }
1355     }
1356 }
1357 
create_cases<'tcx>( body: &mut Body<'tcx>, transform: &TransformVisitor<'tcx>, operation: Operation, ) -> Vec<(usize, BasicBlock)>1358 fn create_cases<'tcx>(
1359     body: &mut Body<'tcx>,
1360     transform: &TransformVisitor<'tcx>,
1361     operation: Operation,
1362 ) -> Vec<(usize, BasicBlock)> {
1363     let source_info = SourceInfo::outermost(body.span);
1364 
1365     transform
1366         .suspension_points
1367         .iter()
1368         .filter_map(|point| {
1369             // Find the target for this suspension point, if applicable
1370             operation.target_block(point).map(|target| {
1371                 let mut statements = Vec::new();
1372 
1373                 // Create StorageLive instructions for locals with live storage
1374                 for i in 0..(body.local_decls.len()) {
1375                     if i == 2 {
1376                         // The resume argument is live on function entry. Don't insert a
1377                         // `StorageLive`, or the following `Assign` will read from uninitialized
1378                         // memory.
1379                         continue;
1380                     }
1381 
1382                     let l = Local::new(i);
1383                     let needs_storage_live = point.storage_liveness.contains(l)
1384                         && !transform.remap.contains_key(&l)
1385                         && !transform.always_live_locals.contains(l);
1386                     if needs_storage_live {
1387                         statements
1388                             .push(Statement { source_info, kind: StatementKind::StorageLive(l) });
1389                     }
1390                 }
1391 
1392                 if operation == Operation::Resume {
1393                     // Move the resume argument to the destination place of the `Yield` terminator
1394                     let resume_arg = Local::new(2); // 0 = return, 1 = self
1395                     statements.push(Statement {
1396                         source_info,
1397                         kind: StatementKind::Assign(Box::new((
1398                             point.resume_arg,
1399                             Rvalue::Use(Operand::Move(resume_arg.into())),
1400                         ))),
1401                     });
1402                 }
1403 
1404                 // Then jump to the real target
1405                 let block = body.basic_blocks_mut().push(BasicBlockData {
1406                     statements,
1407                     terminator: Some(Terminator {
1408                         source_info,
1409                         kind: TerminatorKind::Goto { target },
1410                     }),
1411                     is_cleanup: false,
1412                 });
1413 
1414                 (point.state, block)
1415             })
1416         })
1417         .collect()
1418 }
1419 
1420 #[instrument(level = "debug", skip(tcx), ret)]
mir_generator_witnesses<'tcx>( tcx: TyCtxt<'tcx>, def_id: LocalDefId, ) -> Option<GeneratorLayout<'tcx>>1421 pub(crate) fn mir_generator_witnesses<'tcx>(
1422     tcx: TyCtxt<'tcx>,
1423     def_id: LocalDefId,
1424 ) -> Option<GeneratorLayout<'tcx>> {
1425     assert!(tcx.sess.opts.unstable_opts.drop_tracking_mir);
1426 
1427     let (body, _) = tcx.mir_promoted(def_id);
1428     let body = body.borrow();
1429     let body = &*body;
1430 
1431     // The first argument is the generator type passed by value
1432     let gen_ty = body.local_decls[ty::CAPTURE_STRUCT_LOCAL].ty;
1433 
1434     // Get the interior types and substs which typeck computed
1435     let movable = match *gen_ty.kind() {
1436         ty::Generator(_, _, movability) => movability == hir::Movability::Movable,
1437         ty::Error(_) => return None,
1438         _ => span_bug!(body.span, "unexpected generator type {}", gen_ty),
1439     };
1440 
1441     // When first entering the generator, move the resume argument into its new local.
1442     let always_live_locals = always_storage_live_locals(&body);
1443 
1444     let liveness_info = locals_live_across_suspend_points(tcx, body, &always_live_locals, movable);
1445 
1446     // Extract locals which are live across suspension point into `layout`
1447     // `remap` gives a mapping from local indices onto generator struct indices
1448     // `storage_liveness` tells us which locals have live storage at suspension points
1449     let (_, generator_layout, _) = compute_layout(tcx, liveness_info, body);
1450 
1451     check_suspend_tys(tcx, &generator_layout, &body);
1452 
1453     Some(generator_layout)
1454 }
1455 
1456 impl<'tcx> MirPass<'tcx> for StateTransform {
run_pass(&self, tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>)1457     fn run_pass(&self, tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) {
1458         let Some(yield_ty) = body.yield_ty() else {
1459             // This only applies to generators
1460             return;
1461         };
1462 
1463         assert!(body.generator_drop().is_none());
1464 
1465         // The first argument is the generator type passed by value
1466         let gen_ty = body.local_decls.raw[1].ty;
1467 
1468         // Get the discriminant type and substs which typeck computed
1469         let (discr_ty, upvars, interior, movable) = match *gen_ty.kind() {
1470             ty::Generator(_, substs, movability) => {
1471                 let substs = substs.as_generator();
1472                 (
1473                     substs.discr_ty(tcx),
1474                     substs.upvar_tys().collect::<Vec<_>>(),
1475                     substs.witness(),
1476                     movability == hir::Movability::Movable,
1477                 )
1478             }
1479             _ => {
1480                 tcx.sess.delay_span_bug(body.span, format!("unexpected generator type {}", gen_ty));
1481                 return;
1482             }
1483         };
1484 
1485         let is_async_kind = matches!(body.generator_kind(), Some(GeneratorKind::Async(_)));
1486         let (state_adt_ref, state_substs) = if is_async_kind {
1487             // Compute Poll<return_ty>
1488             let poll_did = tcx.require_lang_item(LangItem::Poll, None);
1489             let poll_adt_ref = tcx.adt_def(poll_did);
1490             let poll_substs = tcx.mk_substs(&[body.return_ty().into()]);
1491             (poll_adt_ref, poll_substs)
1492         } else {
1493             // Compute GeneratorState<yield_ty, return_ty>
1494             let state_did = tcx.require_lang_item(LangItem::GeneratorState, None);
1495             let state_adt_ref = tcx.adt_def(state_did);
1496             let state_substs = tcx.mk_substs(&[yield_ty.into(), body.return_ty().into()]);
1497             (state_adt_ref, state_substs)
1498         };
1499         let ret_ty = Ty::new_adt(tcx, state_adt_ref, state_substs);
1500 
1501         // We rename RETURN_PLACE which has type mir.return_ty to new_ret_local
1502         // RETURN_PLACE then is a fresh unused local with type ret_ty.
1503         let new_ret_local = replace_local(RETURN_PLACE, ret_ty, body, tcx);
1504 
1505         // Replace all occurrences of `ResumeTy` with `&mut Context<'_>` within async bodies.
1506         if is_async_kind {
1507             transform_async_context(tcx, body);
1508         }
1509 
1510         // We also replace the resume argument and insert an `Assign`.
1511         // This is needed because the resume argument `_2` might be live across a `yield`, in which
1512         // case there is no `Assign` to it that the transform can turn into a store to the generator
1513         // state. After the yield the slot in the generator state would then be uninitialized.
1514         let resume_local = Local::new(2);
1515         let resume_ty = if is_async_kind {
1516             Ty::new_task_context(tcx)
1517         } else {
1518             body.local_decls[resume_local].ty
1519         };
1520         let new_resume_local = replace_local(resume_local, resume_ty, body, tcx);
1521 
1522         // When first entering the generator, move the resume argument into its new local.
1523         let source_info = SourceInfo::outermost(body.span);
1524         let stmts = &mut body.basic_blocks_mut()[START_BLOCK].statements;
1525         stmts.insert(
1526             0,
1527             Statement {
1528                 source_info,
1529                 kind: StatementKind::Assign(Box::new((
1530                     new_resume_local.into(),
1531                     Rvalue::Use(Operand::Move(resume_local.into())),
1532                 ))),
1533             },
1534         );
1535 
1536         let always_live_locals = always_storage_live_locals(&body);
1537 
1538         let liveness_info =
1539             locals_live_across_suspend_points(tcx, body, &always_live_locals, movable);
1540 
1541         if tcx.sess.opts.unstable_opts.validate_mir {
1542             let mut vis = EnsureGeneratorFieldAssignmentsNeverAlias {
1543                 assigned_local: None,
1544                 saved_locals: &liveness_info.saved_locals,
1545                 storage_conflicts: &liveness_info.storage_conflicts,
1546             };
1547 
1548             vis.visit_body(body);
1549         }
1550 
1551         // Extract locals which are live across suspension point into `layout`
1552         // `remap` gives a mapping from local indices onto generator struct indices
1553         // `storage_liveness` tells us which locals have live storage at suspension points
1554         let (remap, layout, storage_liveness) = compute_layout(tcx, liveness_info, body);
1555 
1556         if tcx.sess.opts.unstable_opts.validate_mir
1557             && !tcx.sess.opts.unstable_opts.drop_tracking_mir
1558         {
1559             sanitize_witness(tcx, body, interior, upvars, &layout);
1560         }
1561 
1562         let can_return = can_return(tcx, body, tcx.param_env(body.source.def_id()));
1563 
1564         // Run the transformation which converts Places from Local to generator struct
1565         // accesses for locals in `remap`.
1566         // It also rewrites `return x` and `yield y` as writing a new generator state and returning
1567         // either GeneratorState::Complete(x) and GeneratorState::Yielded(y),
1568         // or Poll::Ready(x) and Poll::Pending respectively depending on `is_async_kind`.
1569         let mut transform = TransformVisitor {
1570             tcx,
1571             is_async_kind,
1572             state_adt_ref,
1573             state_substs,
1574             remap,
1575             storage_liveness,
1576             always_live_locals,
1577             suspension_points: Vec::new(),
1578             new_ret_local,
1579             discr_ty,
1580         };
1581         transform.visit_body(body);
1582 
1583         // Update our MIR struct to reflect the changes we've made
1584         body.arg_count = 2; // self, resume arg
1585         body.spread_arg = None;
1586 
1587         // The original arguments to the function are no longer arguments, mark them as such.
1588         // Otherwise they'll conflict with our new arguments, which although they don't have
1589         // argument_index set, will get emitted as unnamed arguments.
1590         for var in &mut body.var_debug_info {
1591             var.argument_index = None;
1592         }
1593 
1594         body.generator.as_mut().unwrap().yield_ty = None;
1595         body.generator.as_mut().unwrap().generator_layout = Some(layout);
1596 
1597         // Insert `drop(generator_struct)` which is used to drop upvars for generators in
1598         // the unresumed state.
1599         // This is expanded to a drop ladder in `elaborate_generator_drops`.
1600         let drop_clean = insert_clean_drop(body);
1601 
1602         dump_mir(tcx, false, "generator_pre-elab", &0, body, |_, _| Ok(()));
1603 
1604         // Expand `drop(generator_struct)` to a drop ladder which destroys upvars.
1605         // If any upvars are moved out of, drop elaboration will handle upvar destruction.
1606         // However we need to also elaborate the code generated by `insert_clean_drop`.
1607         elaborate_generator_drops(tcx, body);
1608 
1609         dump_mir(tcx, false, "generator_post-transform", &0, body, |_, _| Ok(()));
1610 
1611         // Create a copy of our MIR and use it to create the drop shim for the generator
1612         let drop_shim = create_generator_drop_shim(tcx, &transform, gen_ty, body, drop_clean);
1613 
1614         body.generator.as_mut().unwrap().generator_drop = Some(drop_shim);
1615 
1616         // Create the Generator::resume / Future::poll function
1617         create_generator_resume_function(tcx, transform, body, can_return);
1618 
1619         // Run derefer to fix Derefs that are not in the first place
1620         deref_finder(tcx, body);
1621     }
1622 }
1623 
1624 /// Looks for any assignments between locals (e.g., `_4 = _5`) that will both be converted to fields
1625 /// in the generator state machine but whose storage is not marked as conflicting
1626 ///
1627 /// Validation needs to happen immediately *before* `TransformVisitor` is invoked, not after.
1628 ///
1629 /// This condition would arise when the assignment is the last use of `_5` but the initial
1630 /// definition of `_4` if we weren't extra careful to mark all locals used inside a statement as
1631 /// conflicting. Non-conflicting generator saved locals may be stored at the same location within
1632 /// the generator state machine, which would result in ill-formed MIR: the left-hand and right-hand
1633 /// sides of an assignment may not alias. This caused a miscompilation in [#73137].
1634 ///
1635 /// [#73137]: https://github.com/rust-lang/rust/issues/73137
1636 struct EnsureGeneratorFieldAssignmentsNeverAlias<'a> {
1637     saved_locals: &'a GeneratorSavedLocals,
1638     storage_conflicts: &'a BitMatrix<GeneratorSavedLocal, GeneratorSavedLocal>,
1639     assigned_local: Option<GeneratorSavedLocal>,
1640 }
1641 
1642 impl EnsureGeneratorFieldAssignmentsNeverAlias<'_> {
saved_local_for_direct_place(&self, place: Place<'_>) -> Option<GeneratorSavedLocal>1643     fn saved_local_for_direct_place(&self, place: Place<'_>) -> Option<GeneratorSavedLocal> {
1644         if place.is_indirect() {
1645             return None;
1646         }
1647 
1648         self.saved_locals.get(place.local)
1649     }
1650 
check_assigned_place(&mut self, place: Place<'_>, f: impl FnOnce(&mut Self))1651     fn check_assigned_place(&mut self, place: Place<'_>, f: impl FnOnce(&mut Self)) {
1652         if let Some(assigned_local) = self.saved_local_for_direct_place(place) {
1653             assert!(self.assigned_local.is_none(), "`check_assigned_place` must not recurse");
1654 
1655             self.assigned_local = Some(assigned_local);
1656             f(self);
1657             self.assigned_local = None;
1658         }
1659     }
1660 }
1661 
1662 impl<'tcx> Visitor<'tcx> for EnsureGeneratorFieldAssignmentsNeverAlias<'_> {
visit_place(&mut self, place: &Place<'tcx>, context: PlaceContext, location: Location)1663     fn visit_place(&mut self, place: &Place<'tcx>, context: PlaceContext, location: Location) {
1664         let Some(lhs) = self.assigned_local else {
1665             // This visitor only invokes `visit_place` for the right-hand side of an assignment
1666             // and only after setting `self.assigned_local`. However, the default impl of
1667             // `Visitor::super_body` may call `visit_place` with a `NonUseContext` for places
1668             // with debuginfo. Ignore them here.
1669             assert!(!context.is_use());
1670             return;
1671         };
1672 
1673         let Some(rhs) = self.saved_local_for_direct_place(*place) else { return };
1674 
1675         if !self.storage_conflicts.contains(lhs, rhs) {
1676             bug!(
1677                 "Assignment between generator saved locals whose storage is not \
1678                     marked as conflicting: {:?}: {:?} = {:?}",
1679                 location,
1680                 lhs,
1681                 rhs,
1682             );
1683         }
1684     }
1685 
visit_statement(&mut self, statement: &Statement<'tcx>, location: Location)1686     fn visit_statement(&mut self, statement: &Statement<'tcx>, location: Location) {
1687         match &statement.kind {
1688             StatementKind::Assign(box (lhs, rhs)) => {
1689                 self.check_assigned_place(*lhs, |this| this.visit_rvalue(rhs, location));
1690             }
1691 
1692             StatementKind::FakeRead(..)
1693             | StatementKind::SetDiscriminant { .. }
1694             | StatementKind::Deinit(..)
1695             | StatementKind::StorageLive(_)
1696             | StatementKind::StorageDead(_)
1697             | StatementKind::Retag(..)
1698             | StatementKind::AscribeUserType(..)
1699             | StatementKind::PlaceMention(..)
1700             | StatementKind::Coverage(..)
1701             | StatementKind::Intrinsic(..)
1702             | StatementKind::ConstEvalCounter
1703             | StatementKind::Nop => {}
1704         }
1705     }
1706 
visit_terminator(&mut self, terminator: &Terminator<'tcx>, location: Location)1707     fn visit_terminator(&mut self, terminator: &Terminator<'tcx>, location: Location) {
1708         // Checking for aliasing in terminators is probably overkill, but until we have actual
1709         // semantics, we should be conservative here.
1710         match &terminator.kind {
1711             TerminatorKind::Call {
1712                 func,
1713                 args,
1714                 destination,
1715                 target: Some(_),
1716                 unwind: _,
1717                 call_source: _,
1718                 fn_span: _,
1719             } => {
1720                 self.check_assigned_place(*destination, |this| {
1721                     this.visit_operand(func, location);
1722                     for arg in args {
1723                         this.visit_operand(arg, location);
1724                     }
1725                 });
1726             }
1727 
1728             TerminatorKind::Yield { value, resume: _, resume_arg, drop: _ } => {
1729                 self.check_assigned_place(*resume_arg, |this| this.visit_operand(value, location));
1730             }
1731 
1732             // FIXME: Does `asm!` have any aliasing requirements?
1733             TerminatorKind::InlineAsm { .. } => {}
1734 
1735             TerminatorKind::Call { .. }
1736             | TerminatorKind::Goto { .. }
1737             | TerminatorKind::SwitchInt { .. }
1738             | TerminatorKind::Resume
1739             | TerminatorKind::Terminate
1740             | TerminatorKind::Return
1741             | TerminatorKind::Unreachable
1742             | TerminatorKind::Drop { .. }
1743             | TerminatorKind::Assert { .. }
1744             | TerminatorKind::GeneratorDrop
1745             | TerminatorKind::FalseEdge { .. }
1746             | TerminatorKind::FalseUnwind { .. } => {}
1747         }
1748     }
1749 }
1750 
check_suspend_tys<'tcx>(tcx: TyCtxt<'tcx>, layout: &GeneratorLayout<'tcx>, body: &Body<'tcx>)1751 fn check_suspend_tys<'tcx>(tcx: TyCtxt<'tcx>, layout: &GeneratorLayout<'tcx>, body: &Body<'tcx>) {
1752     let mut linted_tys = FxHashSet::default();
1753 
1754     // We want a user-facing param-env.
1755     let param_env = tcx.param_env(body.source.def_id());
1756 
1757     for (variant, yield_source_info) in
1758         layout.variant_fields.iter().zip(&layout.variant_source_info)
1759     {
1760         debug!(?variant);
1761         for &local in variant {
1762             let decl = &layout.field_tys[local];
1763             debug!(?decl);
1764 
1765             if !decl.ignore_for_traits && linted_tys.insert(decl.ty) {
1766                 let Some(hir_id) = decl.source_info.scope.lint_root(&body.source_scopes) else { continue };
1767 
1768                 check_must_not_suspend_ty(
1769                     tcx,
1770                     decl.ty,
1771                     hir_id,
1772                     param_env,
1773                     SuspendCheckData {
1774                         source_span: decl.source_info.span,
1775                         yield_span: yield_source_info.span,
1776                         plural_len: 1,
1777                         ..Default::default()
1778                     },
1779                 );
1780             }
1781         }
1782     }
1783 }
1784 
1785 #[derive(Default)]
1786 struct SuspendCheckData<'a> {
1787     source_span: Span,
1788     yield_span: Span,
1789     descr_pre: &'a str,
1790     descr_post: &'a str,
1791     plural_len: usize,
1792 }
1793 
1794 // Returns whether it emitted a diagnostic or not
1795 // Note that this fn and the proceeding one are based on the code
1796 // for creating must_use diagnostics
1797 //
1798 // Note that this technique was chosen over things like a `Suspend` marker trait
1799 // as it is simpler and has precedent in the compiler
check_must_not_suspend_ty<'tcx>( tcx: TyCtxt<'tcx>, ty: Ty<'tcx>, hir_id: hir::HirId, param_env: ty::ParamEnv<'tcx>, data: SuspendCheckData<'_>, ) -> bool1800 fn check_must_not_suspend_ty<'tcx>(
1801     tcx: TyCtxt<'tcx>,
1802     ty: Ty<'tcx>,
1803     hir_id: hir::HirId,
1804     param_env: ty::ParamEnv<'tcx>,
1805     data: SuspendCheckData<'_>,
1806 ) -> bool {
1807     if ty.is_unit() {
1808         return false;
1809     }
1810 
1811     let plural_suffix = pluralize!(data.plural_len);
1812 
1813     debug!("Checking must_not_suspend for {}", ty);
1814 
1815     match *ty.kind() {
1816         ty::Adt(..) if ty.is_box() => {
1817             let boxed_ty = ty.boxed_ty();
1818             let descr_pre = &format!("{}boxed ", data.descr_pre);
1819             check_must_not_suspend_ty(
1820                 tcx,
1821                 boxed_ty,
1822                 hir_id,
1823                 param_env,
1824                 SuspendCheckData { descr_pre, ..data },
1825             )
1826         }
1827         ty::Adt(def, _) => check_must_not_suspend_def(tcx, def.did(), hir_id, data),
1828         // FIXME: support adding the attribute to TAITs
1829         ty::Alias(ty::Opaque, ty::AliasTy { def_id: def, .. }) => {
1830             let mut has_emitted = false;
1831             for &(predicate, _) in tcx.explicit_item_bounds(def).skip_binder() {
1832                 // We only look at the `DefId`, so it is safe to skip the binder here.
1833                 if let ty::ClauseKind::Trait(ref poly_trait_predicate) =
1834                     predicate.kind().skip_binder()
1835                 {
1836                     let def_id = poly_trait_predicate.trait_ref.def_id;
1837                     let descr_pre = &format!("{}implementer{} of ", data.descr_pre, plural_suffix);
1838                     if check_must_not_suspend_def(
1839                         tcx,
1840                         def_id,
1841                         hir_id,
1842                         SuspendCheckData { descr_pre, ..data },
1843                     ) {
1844                         has_emitted = true;
1845                         break;
1846                     }
1847                 }
1848             }
1849             has_emitted
1850         }
1851         ty::Dynamic(binder, _, _) => {
1852             let mut has_emitted = false;
1853             for predicate in binder.iter() {
1854                 if let ty::ExistentialPredicate::Trait(ref trait_ref) = predicate.skip_binder() {
1855                     let def_id = trait_ref.def_id;
1856                     let descr_post = &format!(" trait object{}{}", plural_suffix, data.descr_post);
1857                     if check_must_not_suspend_def(
1858                         tcx,
1859                         def_id,
1860                         hir_id,
1861                         SuspendCheckData { descr_post, ..data },
1862                     ) {
1863                         has_emitted = true;
1864                         break;
1865                     }
1866                 }
1867             }
1868             has_emitted
1869         }
1870         ty::Tuple(fields) => {
1871             let mut has_emitted = false;
1872             for (i, ty) in fields.iter().enumerate() {
1873                 let descr_post = &format!(" in tuple element {i}");
1874                 if check_must_not_suspend_ty(
1875                     tcx,
1876                     ty,
1877                     hir_id,
1878                     param_env,
1879                     SuspendCheckData { descr_post, ..data },
1880                 ) {
1881                     has_emitted = true;
1882                 }
1883             }
1884             has_emitted
1885         }
1886         ty::Array(ty, len) => {
1887             let descr_pre = &format!("{}array{} of ", data.descr_pre, plural_suffix);
1888             check_must_not_suspend_ty(
1889                 tcx,
1890                 ty,
1891                 hir_id,
1892                 param_env,
1893                 SuspendCheckData {
1894                     descr_pre,
1895                     plural_len: len.try_eval_target_usize(tcx, param_env).unwrap_or(0) as usize + 1,
1896                     ..data
1897                 },
1898             )
1899         }
1900         // If drop tracking is enabled, we want to look through references, since the referent
1901         // may not be considered live across the await point.
1902         ty::Ref(_region, ty, _mutability) => {
1903             let descr_pre = &format!("{}reference{} to ", data.descr_pre, plural_suffix);
1904             check_must_not_suspend_ty(
1905                 tcx,
1906                 ty,
1907                 hir_id,
1908                 param_env,
1909                 SuspendCheckData { descr_pre, ..data },
1910             )
1911         }
1912         _ => false,
1913     }
1914 }
1915 
check_must_not_suspend_def( tcx: TyCtxt<'_>, def_id: DefId, hir_id: hir::HirId, data: SuspendCheckData<'_>, ) -> bool1916 fn check_must_not_suspend_def(
1917     tcx: TyCtxt<'_>,
1918     def_id: DefId,
1919     hir_id: hir::HirId,
1920     data: SuspendCheckData<'_>,
1921 ) -> bool {
1922     if let Some(attr) = tcx.get_attr(def_id, sym::must_not_suspend) {
1923         let reason = attr.value_str().map(|s| errors::MustNotSuspendReason {
1924             span: data.source_span,
1925             reason: s.as_str().to_string(),
1926         });
1927         tcx.emit_spanned_lint(
1928             rustc_session::lint::builtin::MUST_NOT_SUSPEND,
1929             hir_id,
1930             data.source_span,
1931             errors::MustNotSupend {
1932                 yield_sp: data.yield_span,
1933                 reason,
1934                 src_sp: data.source_span,
1935                 pre: data.descr_pre,
1936                 def_path: tcx.def_path_str(def_id),
1937                 post: data.descr_post,
1938             },
1939         );
1940 
1941         true
1942     } else {
1943         false
1944     }
1945 }
1946