• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 mod never_type;
2 mod coercion;
3 mod regression;
4 mod simple;
5 mod patterns;
6 mod traits;
7 mod method_resolution;
8 mod macros;
9 mod display_source_code;
10 mod incremental;
11 mod diagnostics;
12 
13 use std::{collections::HashMap, env};
14 
15 use base_db::{fixture::WithFixture, FileRange, SourceDatabaseExt};
16 use expect_test::Expect;
17 use hir_def::{
18     body::{Body, BodySourceMap, SyntheticSyntax},
19     db::{DefDatabase, InternDatabase},
20     hir::{ExprId, Pat, PatId},
21     item_scope::ItemScope,
22     nameres::DefMap,
23     src::HasSource,
24     AssocItemId, DefWithBodyId, HasModule, LocalModuleId, Lookup, ModuleDefId,
25 };
26 use hir_expand::{db::ExpandDatabase, InFile};
27 use once_cell::race::OnceBool;
28 use stdx::format_to;
29 use syntax::{
30     ast::{self, AstNode, HasName},
31     SyntaxNode,
32 };
33 use tracing_subscriber::{layer::SubscriberExt, EnvFilter, Registry};
34 use tracing_tree::HierarchicalLayer;
35 use triomphe::Arc;
36 
37 use crate::{
38     db::HirDatabase,
39     display::HirDisplay,
40     infer::{Adjustment, TypeMismatch},
41     test_db::TestDB,
42     InferenceResult, Ty,
43 };
44 
45 // These tests compare the inference results for all expressions in a file
46 // against snapshots of the expected results using expect. Use
47 // `env UPDATE_EXPECT=1 cargo test -p hir_ty` to update the snapshots.
48 
setup_tracing() -> Option<tracing::subscriber::DefaultGuard>49 fn setup_tracing() -> Option<tracing::subscriber::DefaultGuard> {
50     static ENABLE: OnceBool = OnceBool::new();
51     if !ENABLE.get_or_init(|| env::var("CHALK_DEBUG").is_ok()) {
52         return None;
53     }
54 
55     let filter = EnvFilter::from_env("CHALK_DEBUG");
56     let layer = HierarchicalLayer::default()
57         .with_indent_lines(true)
58         .with_ansi(false)
59         .with_indent_amount(2)
60         .with_writer(std::io::stderr);
61     let subscriber = Registry::default().with(filter).with(layer);
62     Some(tracing::subscriber::set_default(subscriber))
63 }
64 
65 #[track_caller]
check_types(ra_fixture: &str)66 fn check_types(ra_fixture: &str) {
67     check_impl(ra_fixture, false, true, false)
68 }
69 
70 #[track_caller]
check_types_source_code(ra_fixture: &str)71 fn check_types_source_code(ra_fixture: &str) {
72     check_impl(ra_fixture, false, true, true)
73 }
74 
75 #[track_caller]
check_no_mismatches(ra_fixture: &str)76 fn check_no_mismatches(ra_fixture: &str) {
77     check_impl(ra_fixture, true, false, false)
78 }
79 
80 #[track_caller]
check(ra_fixture: &str)81 fn check(ra_fixture: &str) {
82     check_impl(ra_fixture, false, false, false)
83 }
84 
85 #[track_caller]
check_impl(ra_fixture: &str, allow_none: bool, only_types: bool, display_source: bool)86 fn check_impl(ra_fixture: &str, allow_none: bool, only_types: bool, display_source: bool) {
87     let _tracing = setup_tracing();
88     let (db, files) = TestDB::with_many_files(ra_fixture);
89 
90     let mut had_annotations = false;
91     let mut mismatches = HashMap::new();
92     let mut types = HashMap::new();
93     let mut adjustments = HashMap::<_, Vec<_>>::new();
94     for (file_id, annotations) in db.extract_annotations() {
95         for (range, expected) in annotations {
96             let file_range = FileRange { file_id, range };
97             if only_types {
98                 types.insert(file_range, expected);
99             } else if expected.starts_with("type: ") {
100                 types.insert(file_range, expected.trim_start_matches("type: ").to_string());
101             } else if expected.starts_with("expected") {
102                 mismatches.insert(file_range, expected);
103             } else if expected.starts_with("adjustments:") {
104                 adjustments.insert(
105                     file_range,
106                     expected
107                         .trim_start_matches("adjustments:")
108                         .trim()
109                         .split(',')
110                         .map(|it| it.trim().to_string())
111                         .filter(|it| !it.is_empty())
112                         .collect(),
113                 );
114             } else {
115                 panic!("unexpected annotation: {expected}");
116             }
117             had_annotations = true;
118         }
119     }
120     assert!(had_annotations || allow_none, "no `//^` annotations found");
121 
122     let mut defs: Vec<DefWithBodyId> = Vec::new();
123     for file_id in files {
124         let module = db.module_for_file_opt(file_id);
125         let module = match module {
126             Some(m) => m,
127             None => continue,
128         };
129         let def_map = module.def_map(&db);
130         visit_module(&db, &def_map, module.local_id, &mut |it| defs.push(it));
131     }
132     defs.sort_by_key(|def| match def {
133         DefWithBodyId::FunctionId(it) => {
134             let loc = it.lookup(&db);
135             loc.source(&db).value.syntax().text_range().start()
136         }
137         DefWithBodyId::ConstId(it) => {
138             let loc = it.lookup(&db);
139             loc.source(&db).value.syntax().text_range().start()
140         }
141         DefWithBodyId::StaticId(it) => {
142             let loc = it.lookup(&db);
143             loc.source(&db).value.syntax().text_range().start()
144         }
145         DefWithBodyId::VariantId(it) => {
146             let loc = db.lookup_intern_enum(it.parent);
147             loc.source(&db).value.syntax().text_range().start()
148         }
149         DefWithBodyId::InTypeConstId(it) => it.source(&db).syntax().text_range().start(),
150     });
151     let mut unexpected_type_mismatches = String::new();
152     for def in defs {
153         let (body, body_source_map) = db.body_with_source_map(def);
154         let inference_result = db.infer(def);
155 
156         for (pat, mut ty) in inference_result.type_of_pat.iter() {
157             if let Pat::Bind { id, .. } = body.pats[pat] {
158                 ty = &inference_result.type_of_binding[id];
159             }
160             let node = match pat_node(&body_source_map, pat, &db) {
161                 Some(value) => value,
162                 None => continue,
163             };
164             let range = node.as_ref().original_file_range(&db);
165             if let Some(expected) = types.remove(&range) {
166                 let actual = if display_source {
167                     ty.display_source_code(&db, def.module(&db), true).unwrap()
168                 } else {
169                     ty.display_test(&db).to_string()
170                 };
171                 assert_eq!(actual, expected, "type annotation differs at {:#?}", range.range);
172             }
173         }
174 
175         for (expr, ty) in inference_result.type_of_expr.iter() {
176             let node = match expr_node(&body_source_map, expr, &db) {
177                 Some(value) => value,
178                 None => continue,
179             };
180             let range = node.as_ref().original_file_range(&db);
181             if let Some(expected) = types.remove(&range) {
182                 let actual = if display_source {
183                     ty.display_source_code(&db, def.module(&db), true).unwrap()
184                 } else {
185                     ty.display_test(&db).to_string()
186                 };
187                 assert_eq!(actual, expected, "type annotation differs at {:#?}", range.range);
188             }
189             if let Some(expected) = adjustments.remove(&range) {
190                 let adjustments = inference_result
191                     .expr_adjustments
192                     .get(&expr)
193                     .map_or_else(Default::default, |it| &**it);
194                 assert_eq!(
195                     expected,
196                     adjustments
197                         .iter()
198                         .map(|Adjustment { kind, .. }| format!("{kind:?}"))
199                         .collect::<Vec<_>>()
200                 );
201             }
202         }
203 
204         for (expr_or_pat, mismatch) in inference_result.type_mismatches() {
205             let Some(node) = (match expr_or_pat {
206                 hir_def::hir::ExprOrPatId::ExprId(expr) => expr_node(&body_source_map, expr, &db),
207                 hir_def::hir::ExprOrPatId::PatId(pat) => pat_node(&body_source_map, pat, &db),
208             }) else { continue; };
209             let range = node.as_ref().original_file_range(&db);
210             let actual = format!(
211                 "expected {}, got {}",
212                 mismatch.expected.display_test(&db),
213                 mismatch.actual.display_test(&db)
214             );
215             match mismatches.remove(&range) {
216                 Some(annotation) => assert_eq!(actual, annotation),
217                 None => format_to!(unexpected_type_mismatches, "{:?}: {}\n", range.range, actual),
218             }
219         }
220     }
221 
222     let mut buf = String::new();
223     if !unexpected_type_mismatches.is_empty() {
224         format_to!(buf, "Unexpected type mismatches:\n{}", unexpected_type_mismatches);
225     }
226     if !mismatches.is_empty() {
227         format_to!(buf, "Unchecked mismatch annotations:\n");
228         for m in mismatches {
229             format_to!(buf, "{:?}: {}\n", m.0.range, m.1);
230         }
231     }
232     if !types.is_empty() {
233         format_to!(buf, "Unchecked type annotations:\n");
234         for t in types {
235             format_to!(buf, "{:?}: type {}\n", t.0.range, t.1);
236         }
237     }
238     if !adjustments.is_empty() {
239         format_to!(buf, "Unchecked adjustments annotations:\n");
240         for t in adjustments {
241             format_to!(buf, "{:?}: type {:?}\n", t.0.range, t.1);
242         }
243     }
244     assert!(buf.is_empty(), "{}", buf);
245 }
246 
expr_node( body_source_map: &BodySourceMap, expr: ExprId, db: &TestDB, ) -> Option<InFile<SyntaxNode>>247 fn expr_node(
248     body_source_map: &BodySourceMap,
249     expr: ExprId,
250     db: &TestDB,
251 ) -> Option<InFile<SyntaxNode>> {
252     Some(match body_source_map.expr_syntax(expr) {
253         Ok(sp) => {
254             let root = db.parse_or_expand(sp.file_id);
255             sp.map(|ptr| ptr.to_node(&root).syntax().clone())
256         }
257         Err(SyntheticSyntax) => return None,
258     })
259 }
260 
pat_node( body_source_map: &BodySourceMap, pat: PatId, db: &TestDB, ) -> Option<InFile<SyntaxNode>>261 fn pat_node(
262     body_source_map: &BodySourceMap,
263     pat: PatId,
264     db: &TestDB,
265 ) -> Option<InFile<SyntaxNode>> {
266     Some(match body_source_map.pat_syntax(pat) {
267         Ok(sp) => {
268             let root = db.parse_or_expand(sp.file_id);
269             sp.map(|ptr| {
270                 ptr.either(
271                     |it| it.to_node(&root).syntax().clone(),
272                     |it| it.to_node(&root).syntax().clone(),
273                 )
274             })
275         }
276         Err(SyntheticSyntax) => return None,
277     })
278 }
279 
infer(ra_fixture: &str) -> String280 fn infer(ra_fixture: &str) -> String {
281     infer_with_mismatches(ra_fixture, false)
282 }
283 
infer_with_mismatches(content: &str, include_mismatches: bool) -> String284 fn infer_with_mismatches(content: &str, include_mismatches: bool) -> String {
285     let _tracing = setup_tracing();
286     let (db, file_id) = TestDB::with_single_file(content);
287 
288     let mut buf = String::new();
289 
290     let mut infer_def = |inference_result: Arc<InferenceResult>,
291                          body: Arc<Body>,
292                          body_source_map: Arc<BodySourceMap>| {
293         let mut types: Vec<(InFile<SyntaxNode>, &Ty)> = Vec::new();
294         let mut mismatches: Vec<(InFile<SyntaxNode>, &TypeMismatch)> = Vec::new();
295 
296         for (pat, mut ty) in inference_result.type_of_pat.iter() {
297             if let Pat::Bind { id, .. } = body.pats[pat] {
298                 ty = &inference_result.type_of_binding[id];
299             }
300             let syntax_ptr = match body_source_map.pat_syntax(pat) {
301                 Ok(sp) => {
302                     let root = db.parse_or_expand(sp.file_id);
303                     sp.map(|ptr| {
304                         ptr.either(
305                             |it| it.to_node(&root).syntax().clone(),
306                             |it| it.to_node(&root).syntax().clone(),
307                         )
308                     })
309                 }
310                 Err(SyntheticSyntax) => continue,
311             };
312             types.push((syntax_ptr.clone(), ty));
313             if let Some(mismatch) = inference_result.type_mismatch_for_pat(pat) {
314                 mismatches.push((syntax_ptr, mismatch));
315             }
316         }
317 
318         for (expr, ty) in inference_result.type_of_expr.iter() {
319             let node = match body_source_map.expr_syntax(expr) {
320                 Ok(sp) => {
321                     let root = db.parse_or_expand(sp.file_id);
322                     sp.map(|ptr| ptr.to_node(&root).syntax().clone())
323                 }
324                 Err(SyntheticSyntax) => continue,
325             };
326             types.push((node.clone(), ty));
327             if let Some(mismatch) = inference_result.type_mismatch_for_expr(expr) {
328                 mismatches.push((node, mismatch));
329             }
330         }
331 
332         // sort ranges for consistency
333         types.sort_by_key(|(node, _)| {
334             let range = node.value.text_range();
335             (range.start(), range.end())
336         });
337         for (node, ty) in &types {
338             let (range, text) = if let Some(self_param) = ast::SelfParam::cast(node.value.clone()) {
339                 (self_param.name().unwrap().syntax().text_range(), "self".to_string())
340             } else {
341                 (node.value.text_range(), node.value.text().to_string().replace('\n', " "))
342             };
343             let macro_prefix = if node.file_id != file_id.into() { "!" } else { "" };
344             format_to!(
345                 buf,
346                 "{}{:?} '{}': {}\n",
347                 macro_prefix,
348                 range,
349                 ellipsize(text, 15),
350                 ty.display_test(&db)
351             );
352         }
353         if include_mismatches {
354             mismatches.sort_by_key(|(node, _)| {
355                 let range = node.value.text_range();
356                 (range.start(), range.end())
357             });
358             for (src_ptr, mismatch) in &mismatches {
359                 let range = src_ptr.value.text_range();
360                 let macro_prefix = if src_ptr.file_id != file_id.into() { "!" } else { "" };
361                 format_to!(
362                     buf,
363                     "{}{:?}: expected {}, got {}\n",
364                     macro_prefix,
365                     range,
366                     mismatch.expected.display_test(&db),
367                     mismatch.actual.display_test(&db),
368                 );
369             }
370         }
371     };
372 
373     let module = db.module_for_file(file_id);
374     let def_map = module.def_map(&db);
375 
376     let mut defs: Vec<DefWithBodyId> = Vec::new();
377     visit_module(&db, &def_map, module.local_id, &mut |it| defs.push(it));
378     defs.sort_by_key(|def| match def {
379         DefWithBodyId::FunctionId(it) => {
380             let loc = it.lookup(&db);
381             loc.source(&db).value.syntax().text_range().start()
382         }
383         DefWithBodyId::ConstId(it) => {
384             let loc = it.lookup(&db);
385             loc.source(&db).value.syntax().text_range().start()
386         }
387         DefWithBodyId::StaticId(it) => {
388             let loc = it.lookup(&db);
389             loc.source(&db).value.syntax().text_range().start()
390         }
391         DefWithBodyId::VariantId(it) => {
392             let loc = db.lookup_intern_enum(it.parent);
393             loc.source(&db).value.syntax().text_range().start()
394         }
395         DefWithBodyId::InTypeConstId(it) => it.source(&db).syntax().text_range().start(),
396     });
397     for def in defs {
398         let (body, source_map) = db.body_with_source_map(def);
399         let infer = db.infer(def);
400         infer_def(infer, body, source_map);
401     }
402 
403     buf.truncate(buf.trim_end().len());
404     buf
405 }
406 
visit_module( db: &TestDB, crate_def_map: &DefMap, module_id: LocalModuleId, cb: &mut dyn FnMut(DefWithBodyId), )407 fn visit_module(
408     db: &TestDB,
409     crate_def_map: &DefMap,
410     module_id: LocalModuleId,
411     cb: &mut dyn FnMut(DefWithBodyId),
412 ) {
413     visit_scope(db, crate_def_map, &crate_def_map[module_id].scope, cb);
414     for impl_id in crate_def_map[module_id].scope.impls() {
415         let impl_data = db.impl_data(impl_id);
416         for &item in impl_data.items.iter() {
417             match item {
418                 AssocItemId::FunctionId(it) => {
419                     let def = it.into();
420                     cb(def);
421                     let body = db.body(def);
422                     visit_body(db, &body, cb);
423                 }
424                 AssocItemId::ConstId(it) => {
425                     let def = it.into();
426                     cb(def);
427                     let body = db.body(def);
428                     visit_body(db, &body, cb);
429                 }
430                 AssocItemId::TypeAliasId(_) => (),
431             }
432         }
433     }
434 
435     fn visit_scope(
436         db: &TestDB,
437         crate_def_map: &DefMap,
438         scope: &ItemScope,
439         cb: &mut dyn FnMut(DefWithBodyId),
440     ) {
441         for decl in scope.declarations() {
442             match decl {
443                 ModuleDefId::FunctionId(it) => {
444                     let def = it.into();
445                     cb(def);
446                     let body = db.body(def);
447                     visit_body(db, &body, cb);
448                 }
449                 ModuleDefId::ConstId(it) => {
450                     let def = it.into();
451                     cb(def);
452                     let body = db.body(def);
453                     visit_body(db, &body, cb);
454                 }
455                 ModuleDefId::StaticId(it) => {
456                     let def = it.into();
457                     cb(def);
458                     let body = db.body(def);
459                     visit_body(db, &body, cb);
460                 }
461                 ModuleDefId::AdtId(hir_def::AdtId::EnumId(it)) => {
462                     db.enum_data(it)
463                         .variants
464                         .iter()
465                         .map(|(id, _)| hir_def::EnumVariantId { parent: it, local_id: id })
466                         .for_each(|it| {
467                             let def = it.into();
468                             cb(def);
469                             let body = db.body(def);
470                             visit_body(db, &body, cb);
471                         });
472                 }
473                 ModuleDefId::TraitId(it) => {
474                     let trait_data = db.trait_data(it);
475                     for &(_, item) in trait_data.items.iter() {
476                         match item {
477                             AssocItemId::FunctionId(it) => cb(it.into()),
478                             AssocItemId::ConstId(it) => cb(it.into()),
479                             AssocItemId::TypeAliasId(_) => (),
480                         }
481                     }
482                 }
483                 ModuleDefId::ModuleId(it) => visit_module(db, crate_def_map, it.local_id, cb),
484                 _ => (),
485             }
486         }
487     }
488 
489     fn visit_body(db: &TestDB, body: &Body, cb: &mut dyn FnMut(DefWithBodyId)) {
490         for (_, def_map) in body.blocks(db) {
491             for (mod_id, _) in def_map.modules() {
492                 visit_module(db, &def_map, mod_id, cb);
493             }
494         }
495     }
496 }
497 
/// Shortens `text` to roughly `max_len` bytes by replacing its middle with `...`.
///
/// Text whose byte length is at most `max_len` is returned unchanged. Otherwise the
/// result keeps a prefix and a suffix of `text` with a combined byte budget of
/// `max_len - 3`. Both cut points are moved outward as needed to land on UTF-8
/// character boundaries, so the result may be a few bytes longer than `max_len`.
/// A `max_len` smaller than the ellipsis leaves a zero budget, which yields `...` alone
/// (plus any characters kept to respect character boundaries) instead of panicking.
fn ellipsize(mut text: String, max_len: usize) -> String {
    if text.len() <= max_len {
        return text;
    }
    let ellipsis = "...";
    // Saturating: a `max_len` below the ellipsis length must not underflow.
    let budget = max_len.saturating_sub(ellipsis.len());
    let mut prefix_len = budget / 2;
    while !text.is_char_boundary(prefix_len) {
        prefix_len += 1;
    }
    // Rounding the prefix up to a char boundary may already exceed the budget.
    let mut suffix_len = budget.saturating_sub(prefix_len);
    while !text.is_char_boundary(text.len() - suffix_len) {
        suffix_len += 1;
    }
    text.replace_range(prefix_len..text.len() - suffix_len, ellipsis);
    text
}
515 
check_infer(ra_fixture: &str, expect: Expect)516 fn check_infer(ra_fixture: &str, expect: Expect) {
517     let mut actual = infer(ra_fixture);
518     actual.push('\n');
519     expect.assert_eq(&actual);
520 }
521 
check_infer_with_mismatches(ra_fixture: &str, expect: Expect)522 fn check_infer_with_mismatches(ra_fixture: &str, expect: Expect) {
523     let mut actual = infer_with_mismatches(ra_fixture, true);
524     actual.push('\n');
525     expect.assert_eq(&actual);
526 }
527 
/// Regression test: infers every body, then edits the file (inserting a blank line into
/// `main`) and infers every body again. It passes as long as neither round panics.
/// The name suggests the original failure was in salsa's incremental recomputation,
/// but that is not confirmed here (TODO confirm).
#[test]
fn salsa_bug() {
    // Runs inference for every body-owning definition in the module of `file_id`.
    // It takes `db` as a parameter rather than capturing it, so `db` can still be
    // mutated between calls.
    let infer_all = |db: &TestDB, file_id| {
        let module = db.module_for_file(file_id);
        let crate_def_map = module.def_map(db);
        visit_module(db, &crate_def_map, module.local_id, &mut |def| {
            db.infer(def);
        });
    };

    let (mut db, pos) = TestDB::with_position(
        "
        //- /lib.rs
        trait Index {
            type Output;
        }

        type Key<S: UnificationStoreBase> = <S as UnificationStoreBase>::Key;

        pub trait UnificationStoreBase: Index<Output = Key<Self>> {
            type Key;

            fn len(&self) -> usize;
        }

        pub trait UnificationStoreMut: UnificationStoreBase {
            fn push(&mut self, value: Self::Key);
        }

        fn main() {
            let x = 1;
            x.push(1);$0
        }
    ",
    );
    infer_all(&db, pos.file_id);

    let new_text = "
        //- /lib.rs
        trait Index {
            type Output;
        }

        type Key<S: UnificationStoreBase> = <S as UnificationStoreBase>::Key;

        pub trait UnificationStoreBase: Index<Output = Key<Self>> {
            type Key;

            fn len(&self) -> usize;
        }

        pub trait UnificationStoreMut: UnificationStoreBase {
            fn push(&mut self, value: Self::Key);
        }

        fn main() {

            let x = 1;
            x.push(1);
        }
    ";
    db.set_file_text(pos.file_id, Arc::from(new_text));
    infer_all(&db, pos.file_id);
}
595