1 //! Name resolution for expressions.
2 use hir_expand::name::Name;
3 use la_arena::{Arena, Idx, IdxRange, RawIdx};
4 use rustc_hash::FxHashMap;
5 use triomphe::Arc;
6
7 use crate::{
8 body::Body,
9 db::DefDatabase,
10 hir::{Binding, BindingId, Expr, ExprId, LabelId, Pat, PatId, Statement},
11 BlockId, DefWithBodyId,
12 };
13
/// Index of a [`ScopeData`] inside the scope arena of an [`ExprScopes`].
pub type ScopeId = Idx<ScopeData>;
15
/// The lexical scope tree of a single body.
///
/// Scopes and the name entries they declare are stored in arenas. Every expression visited while
/// building the tree is mapped to a scope; block-like expressions are mapped to the scope they
/// introduce themselves.
#[derive(Debug, PartialEq, Eq)]
pub struct ExprScopes {
    /// All scopes of the body; each one links to its enclosing scope via `ScopeData::parent`.
    scopes: Arena<ScopeData>,
    /// The name entries of all scopes; every scope owns a contiguous sub-range of this arena.
    scope_entries: Arena<ScopeEntry>,
    /// The scope recorded for each visited expression.
    scope_by_expr: FxHashMap<ExprId, ScopeId>,
}
22
/// A single name declared in a scope, together with the binding it refers to.
#[derive(Debug, PartialEq, Eq)]
pub struct ScopeEntry {
    /// The binding's name, cloned from the body's binding table when the entry is created.
    name: Name,
    /// The binding this entry resolves to.
    binding: BindingId,
}
28
29 impl ScopeEntry {
name(&self) -> &Name30 pub fn name(&self) -> &Name {
31 &self.name
32 }
33
binding(&self) -> BindingId34 pub fn binding(&self) -> BindingId {
35 self.binding
36 }
37 }
38
/// A single node of the scope tree.
#[derive(Debug, PartialEq, Eq)]
pub struct ScopeData {
    /// The enclosing scope; `None` only for the root scope of a body.
    parent: Option<ScopeId>,
    /// The id of the block expression (`Expr::Block`, `Expr::Unsafe`, `Expr::Async`) that
    /// introduced this scope, when that block has one.
    block: Option<BlockId>,
    /// The label (id and name) of the labeled block or loop that introduced this scope, if any.
    label: Option<(LabelId, Name)>,
    /// The entries declared directly in this scope, as a range into `ExprScopes::scope_entries`.
    entries: IdxRange<ScopeEntry>,
}
46
47 impl ExprScopes {
expr_scopes_query(db: &dyn DefDatabase, def: DefWithBodyId) -> Arc<ExprScopes>48 pub(crate) fn expr_scopes_query(db: &dyn DefDatabase, def: DefWithBodyId) -> Arc<ExprScopes> {
49 let body = db.body(def);
50 let mut scopes = ExprScopes::new(&body);
51 scopes.shrink_to_fit();
52 Arc::new(scopes)
53 }
54
entries(&self, scope: ScopeId) -> &[ScopeEntry]55 pub fn entries(&self, scope: ScopeId) -> &[ScopeEntry] {
56 &self.scope_entries[self.scopes[scope].entries.clone()]
57 }
58
59 /// If `scope` refers to a block expression scope, returns the corresponding `BlockId`.
block(&self, scope: ScopeId) -> Option<BlockId>60 pub fn block(&self, scope: ScopeId) -> Option<BlockId> {
61 self.scopes[scope].block
62 }
63
64 /// If `scope` refers to a labeled expression scope, returns the corresponding `Label`.
label(&self, scope: ScopeId) -> Option<(LabelId, Name)>65 pub fn label(&self, scope: ScopeId) -> Option<(LabelId, Name)> {
66 self.scopes[scope].label.clone()
67 }
68
69 /// Returns the scopes in ascending order.
scope_chain(&self, scope: Option<ScopeId>) -> impl Iterator<Item = ScopeId> + '_70 pub fn scope_chain(&self, scope: Option<ScopeId>) -> impl Iterator<Item = ScopeId> + '_ {
71 std::iter::successors(scope, move |&scope| self.scopes[scope].parent)
72 }
73
resolve_name_in_scope(&self, scope: ScopeId, name: &Name) -> Option<&ScopeEntry>74 pub fn resolve_name_in_scope(&self, scope: ScopeId, name: &Name) -> Option<&ScopeEntry> {
75 self.scope_chain(Some(scope))
76 .find_map(|scope| self.entries(scope).iter().find(|it| it.name == *name))
77 }
78
scope_for(&self, expr: ExprId) -> Option<ScopeId>79 pub fn scope_for(&self, expr: ExprId) -> Option<ScopeId> {
80 self.scope_by_expr.get(&expr).copied()
81 }
82
scope_by_expr(&self) -> &FxHashMap<ExprId, ScopeId>83 pub fn scope_by_expr(&self) -> &FxHashMap<ExprId, ScopeId> {
84 &self.scope_by_expr
85 }
86 }
87
empty_entries(idx: usize) -> IdxRange<ScopeEntry>88 fn empty_entries(idx: usize) -> IdxRange<ScopeEntry> {
89 IdxRange::new(Idx::from_raw(RawIdx::from(idx as u32))..Idx::from_raw(RawIdx::from(idx as u32)))
90 }
91
92 impl ExprScopes {
new(body: &Body) -> ExprScopes93 fn new(body: &Body) -> ExprScopes {
94 let mut scopes = ExprScopes {
95 scopes: Arena::default(),
96 scope_entries: Arena::default(),
97 scope_by_expr: FxHashMap::default(),
98 };
99 let mut root = scopes.root_scope();
100 scopes.add_params_bindings(body, root, &body.params);
101 compute_expr_scopes(body.body_expr, body, &mut scopes, &mut root);
102 scopes
103 }
104
root_scope(&mut self) -> ScopeId105 fn root_scope(&mut self) -> ScopeId {
106 self.scopes.alloc(ScopeData {
107 parent: None,
108 block: None,
109 label: None,
110 entries: empty_entries(self.scope_entries.len()),
111 })
112 }
113
new_scope(&mut self, parent: ScopeId) -> ScopeId114 fn new_scope(&mut self, parent: ScopeId) -> ScopeId {
115 self.scopes.alloc(ScopeData {
116 parent: Some(parent),
117 block: None,
118 label: None,
119 entries: empty_entries(self.scope_entries.len()),
120 })
121 }
122
new_labeled_scope(&mut self, parent: ScopeId, label: Option<(LabelId, Name)>) -> ScopeId123 fn new_labeled_scope(&mut self, parent: ScopeId, label: Option<(LabelId, Name)>) -> ScopeId {
124 self.scopes.alloc(ScopeData {
125 parent: Some(parent),
126 block: None,
127 label,
128 entries: empty_entries(self.scope_entries.len()),
129 })
130 }
131
new_block_scope( &mut self, parent: ScopeId, block: Option<BlockId>, label: Option<(LabelId, Name)>, ) -> ScopeId132 fn new_block_scope(
133 &mut self,
134 parent: ScopeId,
135 block: Option<BlockId>,
136 label: Option<(LabelId, Name)>,
137 ) -> ScopeId {
138 self.scopes.alloc(ScopeData {
139 parent: Some(parent),
140 block,
141 label,
142 entries: empty_entries(self.scope_entries.len()),
143 })
144 }
145
add_bindings(&mut self, body: &Body, scope: ScopeId, binding: BindingId)146 fn add_bindings(&mut self, body: &Body, scope: ScopeId, binding: BindingId) {
147 let Binding { name, .. } = &body.bindings[binding];
148 let entry = self.scope_entries.alloc(ScopeEntry { name: name.clone(), binding });
149 self.scopes[scope].entries =
150 IdxRange::new_inclusive(self.scopes[scope].entries.start()..=entry);
151 }
152
add_pat_bindings(&mut self, body: &Body, scope: ScopeId, pat: PatId)153 fn add_pat_bindings(&mut self, body: &Body, scope: ScopeId, pat: PatId) {
154 let pattern = &body[pat];
155 if let Pat::Bind { id, .. } = pattern {
156 self.add_bindings(body, scope, *id);
157 }
158
159 pattern.walk_child_pats(|pat| self.add_pat_bindings(body, scope, pat));
160 }
161
add_params_bindings(&mut self, body: &Body, scope: ScopeId, params: &[PatId])162 fn add_params_bindings(&mut self, body: &Body, scope: ScopeId, params: &[PatId]) {
163 params.iter().for_each(|pat| self.add_pat_bindings(body, scope, *pat));
164 }
165
set_scope(&mut self, node: ExprId, scope: ScopeId)166 fn set_scope(&mut self, node: ExprId, scope: ScopeId) {
167 self.scope_by_expr.insert(node, scope);
168 }
169
shrink_to_fit(&mut self)170 fn shrink_to_fit(&mut self) {
171 let ExprScopes { scopes, scope_entries, scope_by_expr } = self;
172 scopes.shrink_to_fit();
173 scope_entries.shrink_to_fit();
174 scope_by_expr.shrink_to_fit();
175 }
176 }
177
compute_block_scopes( statements: &[Statement], tail: Option<ExprId>, body: &Body, scopes: &mut ExprScopes, scope: &mut ScopeId, )178 fn compute_block_scopes(
179 statements: &[Statement],
180 tail: Option<ExprId>,
181 body: &Body,
182 scopes: &mut ExprScopes,
183 scope: &mut ScopeId,
184 ) {
185 for stmt in statements {
186 match stmt {
187 Statement::Let { pat, initializer, else_branch, .. } => {
188 if let Some(expr) = initializer {
189 compute_expr_scopes(*expr, body, scopes, scope);
190 }
191 if let Some(expr) = else_branch {
192 compute_expr_scopes(*expr, body, scopes, scope);
193 }
194
195 *scope = scopes.new_scope(*scope);
196 scopes.add_pat_bindings(body, *scope, *pat);
197 }
198 Statement::Expr { expr, .. } => {
199 compute_expr_scopes(*expr, body, scopes, scope);
200 }
201 }
202 }
203 if let Some(expr) = tail {
204 compute_expr_scopes(expr, body, scopes, scope);
205 }
206 }
207
/// Records a scope for `expr` and, recursively, for its sub-expressions, allocating child scopes
/// where `expr` introduces bindings or labels.
///
/// `scope` is the scope `expr` is evaluated in. It is taken by `&mut` because walking some
/// expressions (see the `Expr::Let` arm) advances it to a newly created child scope. The caller
/// then walks whatever comes next in that advanced scope. This is how an earlier `let` in a `&&`
/// chain makes its bindings visible to later operands.
fn compute_expr_scopes(expr: ExprId, body: &Body, scopes: &mut ExprScopes, scope: &mut ScopeId) {
    // Pairs a label id with a clone of its name, the form stored in `ScopeData::label`.
    let make_label =
        |label: &Option<LabelId>| label.map(|label| (label, body.labels[label].name.clone()));

    // Every visited expression first gets the scope it is evaluated in; the block-like arms
    // below overwrite this with the scope they create.
    scopes.set_scope(expr, *scope);
    match &body[expr] {
        Expr::Block { statements, tail, id, label } => {
            let mut scope = scopes.new_block_scope(*scope, *id, make_label(label));
            // Overwrite the old scope for the block expr, so that every block scope can be found
            // via the block itself (important for blocks that only contain items, no expressions).
            scopes.set_scope(expr, scope);
            compute_block_scopes(statements, *tail, body, scopes, &mut scope);
        }
        Expr::Const(_) => {
            // FIXME: This is broken.
            // NOTE(review): the const block's sub-expressions are not walked here, so none of
            // them get an entry in `scope_by_expr`.
        }
        Expr::Unsafe { id, statements, tail } | Expr::Async { id, statements, tail } => {
            // Handled like `Expr::Block`, except that no label is recorded.
            let mut scope = scopes.new_block_scope(*scope, *id, None);
            // Overwrite the old scope for the block expr, so that every block scope can be found
            // via the block itself (important for blocks that only contain items, no expressions).
            scopes.set_scope(expr, scope);
            compute_block_scopes(statements, *tail, body, scopes, &mut scope);
        }
        Expr::While { condition, body: body_expr, label } => {
            // Condition and body share one labeled scope, so bindings introduced by a `let` in
            // the condition are visible in the loop body.
            let mut scope = scopes.new_labeled_scope(*scope, make_label(label));
            compute_expr_scopes(*condition, body, scopes, &mut scope);
            compute_expr_scopes(*body_expr, body, scopes, &mut scope);
        }
        Expr::Loop { body: body_expr, label } => {
            let mut scope = scopes.new_labeled_scope(*scope, make_label(label));
            compute_expr_scopes(*body_expr, body, scopes, &mut scope);
        }
        Expr::Closure { args, body: body_expr, .. } => {
            // Closure parameters are bound in a fresh child scope in which the body is walked.
            let mut scope = scopes.new_scope(*scope);
            scopes.add_params_bindings(body, scope, args);
            compute_expr_scopes(*body_expr, body, scopes, &mut scope);
        }
        Expr::Match { expr, arms } => {
            // The scrutinee is walked in the enclosing scope.
            compute_expr_scopes(*expr, body, scopes, scope);
            for arm in arms.iter() {
                // Each arm gets its own child scope for its pattern bindings. A guard is walked
                // in a further nested scope, and the arm body is walked in whatever scope the
                // guard leaves behind, so `let` bindings from the guard reach the arm body.
                let mut scope = scopes.new_scope(*scope);
                scopes.add_pat_bindings(body, scope, arm.pat);
                if let Some(guard) = arm.guard {
                    scope = scopes.new_scope(scope);
                    compute_expr_scopes(guard, body, scopes, &mut scope);
                }
                compute_expr_scopes(arm.expr, body, scopes, &mut scope);
            }
        }
        &Expr::If { condition, then_branch, else_branch } => {
            // Condition and then-branch share a child scope, so `if let` bindings reach the
            // then-branch. The else-branch is walked in the enclosing scope and cannot see them.
            let mut then_branch_scope = scopes.new_scope(*scope);
            compute_expr_scopes(condition, body, scopes, &mut then_branch_scope);
            compute_expr_scopes(then_branch, body, scopes, &mut then_branch_scope);
            if let Some(else_branch) = else_branch {
                compute_expr_scopes(else_branch, body, scopes, scope);
            }
        }
        &Expr::Let { pat, expr } => {
            // The scrutinee is walked first. Then the caller's `scope` itself is advanced to a
            // new child scope holding the pattern bindings, so they are visible to whatever is
            // walked next with the same `scope`.
            compute_expr_scopes(expr, body, scopes, scope);
            *scope = scopes.new_scope(*scope);
            scopes.add_pat_bindings(body, *scope, pat);
        }
        // All other expressions: walk their children, threading the same mutable scope through.
        e => e.walk_child_exprs(|e| compute_expr_scopes(e, body, scopes, scope)),
    };
}
273
#[cfg(test)]
mod tests {
    use base_db::{fixture::WithFixture, FileId, SourceDatabase};
    use hir_expand::{name::AsName, InFile};
    use syntax::{algo::find_node_at_offset, ast, AstNode};
    use test_utils::{assert_eq_text, extract_offset};

    use crate::{db::DefDatabase, test_db::TestDB, FunctionId, ModuleDefId};

    /// Returns the function that is the value of the first item in the first module belonging
    /// to `file_id`. Panics if there is no such item or if it is not a function.
    fn find_function(db: &TestDB, file_id: FileId) -> FunctionId {
        let krate = db.test_crate();
        let crate_def_map = db.crate_def_map(krate);

        let module = crate_def_map.modules_for_file(file_id).next().unwrap();
        let (_, def) = crate_def_map[module].scope.entries().next().unwrap();
        match def.take_values().unwrap() {
            ModuleDefId::FunctionId(it) => it,
            _ => panic!(),
        }
    }

    /// Checks which names are visible at the `$0` cursor of `ra_fixture`.
    ///
    /// A `marker` path expression is spliced in at the cursor. The names of all entries along
    /// that expression's scope chain (innermost scope first, entries in insertion order) must
    /// equal `expected`, joined line by line.
    fn do_check(ra_fixture: &str, expected: &[&str]) {
        let (offset, code) = extract_offset(ra_fixture);
        // Re-insert the cursor followed by an identifier, so that the cursor sits on a path
        // expression that has an `ExprId` in the lowered body.
        let code = {
            let mut buf = String::new();
            let off: usize = offset.into();
            buf.push_str(&code[..off]);
            buf.push_str("$0marker");
            buf.push_str(&code[off..]);
            buf
        };

        let (db, position) = TestDB::with_position(&code);
        let file_id = position.file_id;
        let offset = position.offset;

        let file_syntax = db.parse(file_id).syntax_node();
        let marker: ast::PathExpr = find_node_at_offset(&file_syntax, offset).unwrap();
        let function = find_function(&db, file_id);

        let scopes = db.expr_scopes(function.into());
        let (_body, source_map) = db.body_with_source_map(function.into());

        let expr_id = source_map
            .node_expr(InFile { file_id: file_id.into(), value: &marker.into() })
            .unwrap();
        let scope = scopes.scope_for(expr_id);

        let actual = scopes
            .scope_chain(scope)
            .flat_map(|scope| scopes.entries(scope))
            .map(|it| it.name().to_smol_str())
            .collect::<Vec<_>>()
            .join("\n");
        let expected = expected.join("\n");
        assert_eq_text!(&expected, &actual);
    }

    // Closure parameters are visible in the closure body, in front of the outer parameters.
    #[test]
    fn test_lambda_scope() {
        do_check(
            r"
            fn quux(foo: i32) {
                let f = |bar, baz: i32| {
                    $0
                };
            }",
            &["bar", "baz", "foo"],
        );
    }

    #[test]
    fn test_call_scope() {
        do_check(
            r"
            fn quux() {
                f(|x| $0 );
            }",
            &["x"],
        );
    }

    #[test]
    fn test_method_call_scope() {
        do_check(
            r"
            fn quux() {
                z.f(|x| $0 );
            }",
            &["x"],
        );
    }

    #[test]
    fn test_loop_scope() {
        do_check(
            r"
            fn quux() {
                loop {
                    let x = ();
                    $0
                };
            }",
            &["x"],
        );
    }

    #[test]
    fn test_match() {
        do_check(
            r"
            fn quux() {
                match () {
                    Some(x) => {
                        $0
                    }
                };
            }",
            &["x"],
        );
    }

    // Inside its own initializer, a `let` binding is not yet visible: only the parameter is.
    #[test]
    fn test_shadow_variable() {
        do_check(
            r"
            fn foo(x: String) {
                let x : &str = &x$0;
            }",
            &["x"],
        );
    }

    // Both the `@` binding and the bindings of its sub-pattern are in scope, outer one first.
    #[test]
    fn test_bindings_after_at() {
        do_check(
            r"
fn foo() {
    match Some(()) {
        opt @ Some(unit) => {
            $0
        }
        _ => {}
    }
}
",
            &["opt", "unit"],
        );
    }

    #[test]
    fn macro_inner_item() {
        do_check(
            r"
macro_rules! mac {
    () => {{
        fn inner() {}
        inner();
    }};
}

fn foo() {
    mac!();
    $0
}
",
            &[],
        );
    }

    #[test]
    fn broken_inner_item() {
        do_check(
            r"
fn foo() {
    trait {}
    $0
}
",
            &[],
        );
    }

    /// Resolves the name reference at the `$0` cursor of `ra_fixture` through the scope tree.
    /// Asserts that the first definition of the resolved binding is the `ast::Name` found at
    /// `expected_offset` in the parsed file text (the fixture with the cursor marker removed).
    fn do_check_local_name(ra_fixture: &str, expected_offset: u32) {
        let (db, position) = TestDB::with_position(ra_fixture);
        let file_id = position.file_id;
        let offset = position.offset;

        let file = db.parse(file_id).ok().unwrap();
        let expected_name = find_node_at_offset::<ast::Name>(file.syntax(), expected_offset.into())
            .expect("failed to find a name at the target offset");
        let name_ref: ast::NameRef = find_node_at_offset(file.syntax(), offset).unwrap();

        let function = find_function(&db, file_id);

        let scopes = db.expr_scopes(function.into());
        let (body, source_map) = db.body_with_source_map(function.into());

        // The scope of the innermost expression containing the name reference.
        let expr_scope = {
            let expr_ast = name_ref.syntax().ancestors().find_map(ast::Expr::cast).unwrap();
            let expr_id =
                source_map.node_expr(InFile { file_id: file_id.into(), value: &expr_ast }).unwrap();
            scopes.scope_for(expr_id).unwrap()
        };

        let resolved = scopes.resolve_name_in_scope(expr_scope, &name_ref.as_name()).unwrap();
        let pat_src = source_map
            .pat_syntax(*body.bindings[resolved.binding()].definitions.first().unwrap())
            .unwrap();

        let local_name = pat_src.value.either(
            |it| it.syntax_node_ptr().to_node(file.syntax()),
            |it| it.syntax_node_ptr().to_node(file.syntax()),
        );
        assert_eq!(local_name.text_range(), expected_name.syntax().text_range());
    }

    // Offset 7 is the parameter `x`.
    #[test]
    fn test_resolve_local_name() {
        do_check_local_name(
            r#"
fn foo(x: i32, y: u32) {
    {
        let z = x * 2;
    }
    {
        let t = x$0 * 3;
    }
}
"#,
            7,
        );
    }

    // A use inside its own `let` initializer resolves to the parameter (offset 7).
    #[test]
    fn test_resolve_local_name_declaration() {
        do_check_local_name(
            r#"
fn foo(x: String) {
    let x : &str = &x$0;
}
"#,
            7,
        );
    }

    // After the `let`, `x` resolves to the shadowing local (offset 28).
    #[test]
    fn test_resolve_local_name_shadow() {
        do_check_local_name(
            r"
fn foo(x: String) {
    let x : &str = &x;
    x$0
}
",
            28,
        );
    }

    #[test]
    fn ref_patterns_contribute_bindings() {
        do_check_local_name(
            r"
fn foo() {
    if let Some(&from) = bar() {
        from$0;
    }
}
",
            28,
        );
    }

    #[test]
    fn while_let_adds_binding() {
        do_check_local_name(
            r#"
fn test() {
    let foo: Option<f32> = None;
    while let Option::Some(spam) = foo {
        spam$0
    }
}
"#,
            75,
        );
        do_check_local_name(
            r#"
fn test() {
    let foo: Option<f32> = None;
    while (((let Option::Some(_) = foo))) && let Option::Some(spam) = foo {
        spam$0
    }
}
"#,
            107,
        );
    }

    #[test]
    fn match_guard_if_let() {
        do_check_local_name(
            r#"
fn test() {
    let foo: Option<f32> = None;
    match foo {
        _ if let Option::Some(spam) = foo => spam$0,
    }
}
"#,
            93,
        );
    }

    // Each use resolves to the nearest preceding `let` in the chain.
    #[test]
    fn let_chains_can_reference_previous_lets() {
        do_check_local_name(
            r#"
fn test() {
    let foo: Option<i32> = None;
    if let Some(spam) = foo && spa$0m > 1 && let Some(spam) = foo && spam > 1 {}
}
"#,
            61,
        );
        do_check_local_name(
            r#"
fn test() {
    let foo: Option<i32> = None;
    if let Some(spam) = foo && spam > 1 && let Some(spam) = foo && sp$0am > 1 {}
}
"#,
            100,
        );
    }
}
610