• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 //! Name resolution for expressions.
2 use hir_expand::name::Name;
3 use la_arena::{Arena, Idx, IdxRange, RawIdx};
4 use rustc_hash::FxHashMap;
5 use triomphe::Arc;
6 
7 use crate::{
8     body::Body,
9     db::DefDatabase,
10     hir::{Binding, BindingId, Expr, ExprId, LabelId, Pat, PatId, Statement},
11     BlockId, DefWithBodyId,
12 };
13 
/// Index of a [`ScopeData`] stored in the `scopes` arena of an [`ExprScopes`].
pub type ScopeId = Idx<ScopeData>;
15 
/// The lexical scopes of a single body, together with the scope recorded for each expression.
#[derive(Debug, PartialEq, Eq)]
pub struct ExprScopes {
    /// Every scope of the body; a [`ScopeId`] indexes into this arena.
    scopes: Arena<ScopeData>,
    /// The entries of all scopes; each scope refers to a range of this arena.
    scope_entries: Arena<ScopeEntry>,
    /// The scope recorded for each visited expression.
    scope_by_expr: FxHashMap<ExprId, ScopeId>,
}
22 
/// A single name made visible by a scope, paired with the binding it refers to.
#[derive(Debug, PartialEq, Eq)]
pub struct ScopeEntry {
    /// The name of the binding, cloned from the body's binding table.
    name: Name,
    /// The binding that `name` refers to.
    binding: BindingId,
}
28 
29 impl ScopeEntry {
name(&self) -> &Name30     pub fn name(&self) -> &Name {
31         &self.name
32     }
33 
binding(&self) -> BindingId34     pub fn binding(&self) -> BindingId {
35         self.binding
36     }
37 }
38 
/// One node of the scope tree.
#[derive(Debug, PartialEq, Eq)]
pub struct ScopeData {
    /// The enclosing scope; `None` for the root scope.
    parent: Option<ScopeId>,
    /// The block id, if this scope was created for a block-like expression that has one.
    block: Option<BlockId>,
    /// The label attached to the expression that created this scope, if any.
    label: Option<(LabelId, Name)>,
    /// The range of `ExprScopes::scope_entries` that belongs to this scope.
    entries: IdxRange<ScopeEntry>,
}
46 
47 impl ExprScopes {
expr_scopes_query(db: &dyn DefDatabase, def: DefWithBodyId) -> Arc<ExprScopes>48     pub(crate) fn expr_scopes_query(db: &dyn DefDatabase, def: DefWithBodyId) -> Arc<ExprScopes> {
49         let body = db.body(def);
50         let mut scopes = ExprScopes::new(&body);
51         scopes.shrink_to_fit();
52         Arc::new(scopes)
53     }
54 
entries(&self, scope: ScopeId) -> &[ScopeEntry]55     pub fn entries(&self, scope: ScopeId) -> &[ScopeEntry] {
56         &self.scope_entries[self.scopes[scope].entries.clone()]
57     }
58 
59     /// If `scope` refers to a block expression scope, returns the corresponding `BlockId`.
block(&self, scope: ScopeId) -> Option<BlockId>60     pub fn block(&self, scope: ScopeId) -> Option<BlockId> {
61         self.scopes[scope].block
62     }
63 
64     /// If `scope` refers to a labeled expression scope, returns the corresponding `Label`.
label(&self, scope: ScopeId) -> Option<(LabelId, Name)>65     pub fn label(&self, scope: ScopeId) -> Option<(LabelId, Name)> {
66         self.scopes[scope].label.clone()
67     }
68 
69     /// Returns the scopes in ascending order.
scope_chain(&self, scope: Option<ScopeId>) -> impl Iterator<Item = ScopeId> + '_70     pub fn scope_chain(&self, scope: Option<ScopeId>) -> impl Iterator<Item = ScopeId> + '_ {
71         std::iter::successors(scope, move |&scope| self.scopes[scope].parent)
72     }
73 
resolve_name_in_scope(&self, scope: ScopeId, name: &Name) -> Option<&ScopeEntry>74     pub fn resolve_name_in_scope(&self, scope: ScopeId, name: &Name) -> Option<&ScopeEntry> {
75         self.scope_chain(Some(scope))
76             .find_map(|scope| self.entries(scope).iter().find(|it| it.name == *name))
77     }
78 
scope_for(&self, expr: ExprId) -> Option<ScopeId>79     pub fn scope_for(&self, expr: ExprId) -> Option<ScopeId> {
80         self.scope_by_expr.get(&expr).copied()
81     }
82 
scope_by_expr(&self) -> &FxHashMap<ExprId, ScopeId>83     pub fn scope_by_expr(&self) -> &FxHashMap<ExprId, ScopeId> {
84         &self.scope_by_expr
85     }
86 }
87 
empty_entries(idx: usize) -> IdxRange<ScopeEntry>88 fn empty_entries(idx: usize) -> IdxRange<ScopeEntry> {
89     IdxRange::new(Idx::from_raw(RawIdx::from(idx as u32))..Idx::from_raw(RawIdx::from(idx as u32)))
90 }
91 
92 impl ExprScopes {
new(body: &Body) -> ExprScopes93     fn new(body: &Body) -> ExprScopes {
94         let mut scopes = ExprScopes {
95             scopes: Arena::default(),
96             scope_entries: Arena::default(),
97             scope_by_expr: FxHashMap::default(),
98         };
99         let mut root = scopes.root_scope();
100         scopes.add_params_bindings(body, root, &body.params);
101         compute_expr_scopes(body.body_expr, body, &mut scopes, &mut root);
102         scopes
103     }
104 
root_scope(&mut self) -> ScopeId105     fn root_scope(&mut self) -> ScopeId {
106         self.scopes.alloc(ScopeData {
107             parent: None,
108             block: None,
109             label: None,
110             entries: empty_entries(self.scope_entries.len()),
111         })
112     }
113 
new_scope(&mut self, parent: ScopeId) -> ScopeId114     fn new_scope(&mut self, parent: ScopeId) -> ScopeId {
115         self.scopes.alloc(ScopeData {
116             parent: Some(parent),
117             block: None,
118             label: None,
119             entries: empty_entries(self.scope_entries.len()),
120         })
121     }
122 
new_labeled_scope(&mut self, parent: ScopeId, label: Option<(LabelId, Name)>) -> ScopeId123     fn new_labeled_scope(&mut self, parent: ScopeId, label: Option<(LabelId, Name)>) -> ScopeId {
124         self.scopes.alloc(ScopeData {
125             parent: Some(parent),
126             block: None,
127             label,
128             entries: empty_entries(self.scope_entries.len()),
129         })
130     }
131 
new_block_scope( &mut self, parent: ScopeId, block: Option<BlockId>, label: Option<(LabelId, Name)>, ) -> ScopeId132     fn new_block_scope(
133         &mut self,
134         parent: ScopeId,
135         block: Option<BlockId>,
136         label: Option<(LabelId, Name)>,
137     ) -> ScopeId {
138         self.scopes.alloc(ScopeData {
139             parent: Some(parent),
140             block,
141             label,
142             entries: empty_entries(self.scope_entries.len()),
143         })
144     }
145 
add_bindings(&mut self, body: &Body, scope: ScopeId, binding: BindingId)146     fn add_bindings(&mut self, body: &Body, scope: ScopeId, binding: BindingId) {
147         let Binding { name, .. } = &body.bindings[binding];
148         let entry = self.scope_entries.alloc(ScopeEntry { name: name.clone(), binding });
149         self.scopes[scope].entries =
150             IdxRange::new_inclusive(self.scopes[scope].entries.start()..=entry);
151     }
152 
add_pat_bindings(&mut self, body: &Body, scope: ScopeId, pat: PatId)153     fn add_pat_bindings(&mut self, body: &Body, scope: ScopeId, pat: PatId) {
154         let pattern = &body[pat];
155         if let Pat::Bind { id, .. } = pattern {
156             self.add_bindings(body, scope, *id);
157         }
158 
159         pattern.walk_child_pats(|pat| self.add_pat_bindings(body, scope, pat));
160     }
161 
add_params_bindings(&mut self, body: &Body, scope: ScopeId, params: &[PatId])162     fn add_params_bindings(&mut self, body: &Body, scope: ScopeId, params: &[PatId]) {
163         params.iter().for_each(|pat| self.add_pat_bindings(body, scope, *pat));
164     }
165 
set_scope(&mut self, node: ExprId, scope: ScopeId)166     fn set_scope(&mut self, node: ExprId, scope: ScopeId) {
167         self.scope_by_expr.insert(node, scope);
168     }
169 
shrink_to_fit(&mut self)170     fn shrink_to_fit(&mut self) {
171         let ExprScopes { scopes, scope_entries, scope_by_expr } = self;
172         scopes.shrink_to_fit();
173         scope_entries.shrink_to_fit();
174         scope_by_expr.shrink_to_fit();
175     }
176 }
177 
compute_block_scopes( statements: &[Statement], tail: Option<ExprId>, body: &Body, scopes: &mut ExprScopes, scope: &mut ScopeId, )178 fn compute_block_scopes(
179     statements: &[Statement],
180     tail: Option<ExprId>,
181     body: &Body,
182     scopes: &mut ExprScopes,
183     scope: &mut ScopeId,
184 ) {
185     for stmt in statements {
186         match stmt {
187             Statement::Let { pat, initializer, else_branch, .. } => {
188                 if let Some(expr) = initializer {
189                     compute_expr_scopes(*expr, body, scopes, scope);
190                 }
191                 if let Some(expr) = else_branch {
192                     compute_expr_scopes(*expr, body, scopes, scope);
193                 }
194 
195                 *scope = scopes.new_scope(*scope);
196                 scopes.add_pat_bindings(body, *scope, *pat);
197             }
198             Statement::Expr { expr, .. } => {
199                 compute_expr_scopes(*expr, body, scopes, scope);
200             }
201         }
202     }
203     if let Some(expr) = tail {
204         compute_expr_scopes(expr, body, scopes, scope);
205     }
206 }
207 
/// Records the scope of `expr` and recurses into its children, allocating new scopes where
/// the expression introduces a block, a label, or bindings.
///
/// `scope` is the scope in effect at `expr`. It is a mutable reference because an
/// `Expr::Let` replaces it with a fresh child scope holding the pattern's bindings; callers
/// that keep passing the same `scope` to later expressions therefore use that child scope.
fn compute_expr_scopes(expr: ExprId, body: &Body, scopes: &mut ExprScopes, scope: &mut ScopeId) {
    // Pairs a label id with a clone of its name, the form stored in `ScopeData::label`.
    let make_label =
        |label: &Option<LabelId>| label.map(|label| (label, body.labels[label].name.clone()));

    // Record the incoming scope first; the block-like arms below overwrite this entry.
    scopes.set_scope(expr, *scope);
    match &body[expr] {
        Expr::Block { statements, tail, id, label } => {
            let mut scope = scopes.new_block_scope(*scope, *id, make_label(label));
            // Overwrite the old scope for the block expr, so that every block scope can be found
            // via the block itself (important for blocks that only contain items, no expressions).
            scopes.set_scope(expr, scope);
            compute_block_scopes(statements, *tail, body, scopes, &mut scope);
        }
        Expr::Const(_) => {
            // FIXME: This is broken.
            // Nothing is visited here, so only the const expression itself gets a scope entry.
        }
        Expr::Unsafe { id, statements, tail } | Expr::Async { id, statements, tail } => {
            let mut scope = scopes.new_block_scope(*scope, *id, None);
            // Overwrite the old scope for the block expr, so that every block scope can be found
            // via the block itself (important for blocks that only contain items, no expressions).
            scopes.set_scope(expr, scope);
            compute_block_scopes(statements, *tail, body, scopes, &mut scope);
        }
        Expr::While { condition, body: body_expr, label } => {
            // The condition and the body share one local scope variable, so a `let` in the
            // condition rebinds it and its bindings are visible in the loop body.
            let mut scope = scopes.new_labeled_scope(*scope, make_label(label));
            compute_expr_scopes(*condition, body, scopes, &mut scope);
            compute_expr_scopes(*body_expr, body, scopes, &mut scope);
        }
        Expr::Loop { body: body_expr, label } => {
            let mut scope = scopes.new_labeled_scope(*scope, make_label(label));
            compute_expr_scopes(*body_expr, body, scopes, &mut scope);
        }
        Expr::Closure { args, body: body_expr, .. } => {
            // The closure parameters live in a new scope in which the closure body is computed.
            let mut scope = scopes.new_scope(*scope);
            scopes.add_params_bindings(body, scope, args);
            compute_expr_scopes(*body_expr, body, scopes, &mut scope);
        }
        Expr::Match { expr, arms } => {
            // The scrutinee is computed in the surrounding scope.
            compute_expr_scopes(*expr, body, scopes, scope);
            for arm in arms.iter() {
                // Every arm gets its own child of the surrounding scope for its pattern.
                let mut scope = scopes.new_scope(*scope);
                scopes.add_pat_bindings(body, scope, arm.pat);
                if let Some(guard) = arm.guard {
                    // The guard gets a nested scope; a `let` inside the guard rebinds `scope`,
                    // and the arm expression below is computed in whatever scope results.
                    scope = scopes.new_scope(scope);
                    compute_expr_scopes(guard, body, scopes, &mut scope);
                }
                compute_expr_scopes(arm.expr, body, scopes, &mut scope);
            }
        }
        &Expr::If { condition, then_branch, else_branch } => {
            // The condition and the `then` branch share a scope, so `let` bindings from the
            // condition are visible in the `then` branch.
            let mut then_branch_scope = scopes.new_scope(*scope);
            compute_expr_scopes(condition, body, scopes, &mut then_branch_scope);
            compute_expr_scopes(then_branch, body, scopes, &mut then_branch_scope);
            if let Some(else_branch) = else_branch {
                // The `else` branch is computed in the surrounding scope instead.
                compute_expr_scopes(else_branch, body, scopes, scope);
            }
        }
        &Expr::Let { pat, expr } => {
            // The initializer is computed before the new scope exists, so it does not see the
            // bindings of `pat`.
            compute_expr_scopes(expr, body, scopes, scope);
            // Rebind the caller's scope to a child scope that holds the bindings of `pat`.
            *scope = scopes.new_scope(*scope);
            scopes.add_pat_bindings(body, *scope, pat);
        }
        // Every other expression passes the current scope on to its child expressions.
        e => e.walk_child_exprs(|e| compute_expr_scopes(e, body, scopes, scope)),
    };
}
273 
#[cfg(test)]
mod tests {
    use base_db::{fixture::WithFixture, FileId, SourceDatabase};
    use hir_expand::{name::AsName, InFile};
    use syntax::{algo::find_node_at_offset, ast, AstNode};
    use test_utils::{assert_eq_text, extract_offset};

    use crate::{db::DefDatabase, test_db::TestDB, FunctionId, ModuleDefId};

    /// Returns the function behind the first scope entry of the first module found for
    /// `file_id`.
    ///
    /// Panics if there is no such module or entry, or if the entry's value is not a function.
    fn find_function(db: &TestDB, file_id: FileId) -> FunctionId {
        let krate = db.test_crate();
        let crate_def_map = db.crate_def_map(krate);

        let module = crate_def_map.modules_for_file(file_id).next().unwrap();
        let (_, def) = crate_def_map[module].scope.entries().next().unwrap();
        match def.take_values().unwrap() {
            ModuleDefId::FunctionId(it) => it,
            _ => panic!(),
        }
    }

    /// Inserts the identifier `marker` at the cursor of `ra_fixture` and asserts that the
    /// names visible from that expression are exactly `expected`.
    ///
    /// Names are listed scope by scope along the scope chain, starting with the scope
    /// recorded for the marker expression.
    fn do_check(ra_fixture: &str, expected: &[&str]) {
        let (offset, code) = extract_offset(ra_fixture);
        // Rebuild the fixture with `$0marker` at the cursor offset, so that there is a path
        // expression to look up at the cursor.
        let code = {
            let mut buf = String::new();
            let off: usize = offset.into();
            buf.push_str(&code[..off]);
            buf.push_str("$0marker");
            buf.push_str(&code[off..]);
            buf
        };

        let (db, position) = TestDB::with_position(&code);
        let file_id = position.file_id;
        let offset = position.offset;

        let file_syntax = db.parse(file_id).syntax_node();
        let marker: ast::PathExpr = find_node_at_offset(&file_syntax, offset).unwrap();
        let function = find_function(&db, file_id);

        let scopes = db.expr_scopes(function.into());
        let (_body, source_map) = db.body_with_source_map(function.into());

        let expr_id = source_map
            .node_expr(InFile { file_id: file_id.into(), value: &marker.into() })
            .unwrap();
        let scope = scopes.scope_for(expr_id);

        // One name per line, in scope-chain order.
        let actual = scopes
            .scope_chain(scope)
            .flat_map(|scope| scopes.entries(scope))
            .map(|it| it.name().to_smol_str())
            .collect::<Vec<_>>()
            .join("\n");
        let expected = expected.join("\n");
        assert_eq_text!(&expected, &actual);
    }

    #[test]
    fn test_lambda_scope() {
        do_check(
            r"
            fn quux(foo: i32) {
                let f = |bar, baz: i32| {
                    $0
                };
            }",
            &["bar", "baz", "foo"],
        );
    }

    #[test]
    fn test_call_scope() {
        do_check(
            r"
            fn quux() {
                f(|x| $0 );
            }",
            &["x"],
        );
    }

    #[test]
    fn test_method_call_scope() {
        do_check(
            r"
            fn quux() {
                z.f(|x| $0 );
            }",
            &["x"],
        );
    }

    #[test]
    fn test_loop_scope() {
        do_check(
            r"
            fn quux() {
                loop {
                    let x = ();
                    $0
                };
            }",
            &["x"],
        );
    }

    #[test]
    fn test_match() {
        do_check(
            r"
            fn quux() {
                match () {
                    Some(x) => {
                        $0
                    }
                };
            }",
            &["x"],
        );
    }

    // Inside the initializer only the parameter `x` is expected to be listed.
    #[test]
    fn test_shadow_variable() {
        do_check(
            r"
            fn foo(x: String) {
                let x : &str = &x$0;
            }",
            &["x"],
        );
    }

    #[test]
    fn test_bindings_after_at() {
        do_check(
            r"
fn foo() {
    match Some(()) {
        opt @ Some(unit) => {
            $0
        }
        _ => {}
    }
}
",
            &["opt", "unit"],
        );
    }

    #[test]
    fn macro_inner_item() {
        do_check(
            r"
            macro_rules! mac {
                () => {{
                    fn inner() {}
                    inner();
                }};
            }

            fn foo() {
                mac!();
                $0
            }
        ",
            &[],
        );
    }

    #[test]
    fn broken_inner_item() {
        do_check(
            r"
            fn foo() {
                trait {}
                $0
            }
        ",
            &[],
        );
    }

    /// Resolves the name reference at the cursor of `ra_fixture` and asserts that it resolves
    /// to the binding whose defining pattern covers the same text range as the `ast::Name`
    /// found at `expected_offset`.
    fn do_check_local_name(ra_fixture: &str, expected_offset: u32) {
        let (db, position) = TestDB::with_position(ra_fixture);
        let file_id = position.file_id;
        let offset = position.offset;

        let file = db.parse(file_id).ok().unwrap();
        let expected_name = find_node_at_offset::<ast::Name>(file.syntax(), expected_offset.into())
            .expect("failed to find a name at the target offset");
        let name_ref: ast::NameRef = find_node_at_offset(file.syntax(), offset).unwrap();

        let function = find_function(&db, file_id);

        let scopes = db.expr_scopes(function.into());
        let (body, source_map) = db.body_with_source_map(function.into());

        // The scope is looked up via the nearest expression enclosing the name reference.
        let expr_scope = {
            let expr_ast = name_ref.syntax().ancestors().find_map(ast::Expr::cast).unwrap();
            let expr_id =
                source_map.node_expr(InFile { file_id: file_id.into(), value: &expr_ast }).unwrap();
            scopes.scope_for(expr_id).unwrap()
        };

        let resolved = scopes.resolve_name_in_scope(expr_scope, &name_ref.as_name()).unwrap();
        // Map the first definition of the resolved binding back to its syntax node.
        let pat_src = source_map
            .pat_syntax(*body.bindings[resolved.binding()].definitions.first().unwrap())
            .unwrap();

        let local_name = pat_src.value.either(
            |it| it.syntax_node_ptr().to_node(file.syntax()),
            |it| it.syntax_node_ptr().to_node(file.syntax()),
        );
        assert_eq!(local_name.text_range(), expected_name.syntax().text_range());
    }

    #[test]
    fn test_resolve_local_name() {
        do_check_local_name(
            r#"
fn foo(x: i32, y: u32) {
    {
        let z = x * 2;
    }
    {
        let t = x$0 * 3;
    }
}
"#,
            7,
        );
    }

    #[test]
    fn test_resolve_local_name_declaration() {
        do_check_local_name(
            r#"
fn foo(x: String) {
    let x : &str = &x$0;
}
"#,
            7,
        );
    }

    #[test]
    fn test_resolve_local_name_shadow() {
        do_check_local_name(
            r"
fn foo(x: String) {
    let x : &str = &x;
    x$0
}
",
            28,
        );
    }

    #[test]
    fn ref_patterns_contribute_bindings() {
        do_check_local_name(
            r"
fn foo() {
    if let Some(&from) = bar() {
        from$0;
    }
}
",
            28,
        );
    }

    #[test]
    fn while_let_adds_binding() {
        do_check_local_name(
            r#"
fn test() {
    let foo: Option<f32> = None;
    while let Option::Some(spam) = foo {
        spam$0
    }
}
"#,
            75,
        );
        do_check_local_name(
            r#"
fn test() {
    let foo: Option<f32> = None;
    while (((let Option::Some(_) = foo))) && let Option::Some(spam) = foo {
        spam$0
    }
}
"#,
            107,
        );
    }

    #[test]
    fn match_guard_if_let() {
        do_check_local_name(
            r#"
fn test() {
    let foo: Option<f32> = None;
    match foo {
        _ if let Option::Some(spam) = foo => spam$0,
    }
}
"#,
            93,
        );
    }

    // Each use of `spam` in the chain is checked against a different `let` binding of `spam`.
    #[test]
    fn let_chains_can_reference_previous_lets() {
        do_check_local_name(
            r#"
fn test() {
    let foo: Option<i32> = None;
    if let Some(spam) = foo && spa$0m > 1 && let Some(spam) = foo && spam > 1 {}
}
"#,
            61,
        );
        do_check_local_name(
            r#"
fn test() {
    let foo: Option<i32> = None;
    if let Some(spam) = foo && spam > 1 && let Some(spam) = foo && sp$0am > 1 {}
}
"#,
            100,
        );
    }
}
610