• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 //! To make attribute macros work reliably when typing, we need to take care to
2 //! fix up syntax errors in the code we're passing to them.
3 use std::mem;
4 
5 use mbe::{SyntheticToken, SyntheticTokenId, TokenMap};
6 use rustc_hash::FxHashMap;
7 use smallvec::SmallVec;
8 use syntax::{
9     ast::{self, AstNode, HasLoopBody},
10     match_ast, SyntaxElement, SyntaxKind, SyntaxNode, TextRange,
11 };
12 use tt::token_id::Subtree;
13 
14 /// The result of calculating fixes for a syntax node -- a bunch of changes
15 /// (appending to and replacing nodes), the information that is needed to
16 /// reverse those changes afterwards, and a token map.
17 #[derive(Debug, Default)]
18 pub(crate) struct SyntaxFixups {
19     pub(crate) append: FxHashMap<SyntaxElement, Vec<SyntheticToken>>,
20     pub(crate) replace: FxHashMap<SyntaxElement, Vec<SyntheticToken>>,
21     pub(crate) undo_info: SyntaxFixupUndoInfo,
22     pub(crate) token_map: TokenMap,
23     pub(crate) next_id: u32,
24 }
25 
26 /// This is the information needed to reverse the fixups.
27 #[derive(Debug, Default, PartialEq, Eq)]
28 pub struct SyntaxFixupUndoInfo {
29     original: Vec<Subtree>,
30 }
31 
32 const EMPTY_ID: SyntheticTokenId = SyntheticTokenId(!0);
33 
fixup_syntax(node: &SyntaxNode) -> SyntaxFixups34 pub(crate) fn fixup_syntax(node: &SyntaxNode) -> SyntaxFixups {
35     let mut append = FxHashMap::<SyntaxElement, _>::default();
36     let mut replace = FxHashMap::<SyntaxElement, _>::default();
37     let mut preorder = node.preorder();
38     let mut original = Vec::new();
39     let mut token_map = TokenMap::default();
40     let mut next_id = 0;
41     while let Some(event) = preorder.next() {
42         let node = match event {
43             syntax::WalkEvent::Enter(node) => node,
44             syntax::WalkEvent::Leave(_) => continue,
45         };
46 
47         if can_handle_error(&node) && has_error_to_handle(&node) {
48             // the node contains an error node, we have to completely replace it by something valid
49             let (original_tree, new_tmap, new_next_id) =
50                 mbe::syntax_node_to_token_tree_with_modifications(
51                     &node,
52                     mem::take(&mut token_map),
53                     next_id,
54                     Default::default(),
55                     Default::default(),
56                 );
57             token_map = new_tmap;
58             next_id = new_next_id;
59             let idx = original.len() as u32;
60             original.push(original_tree);
61             let replacement = SyntheticToken {
62                 kind: SyntaxKind::IDENT,
63                 text: "__ra_fixup".into(),
64                 range: node.text_range(),
65                 id: SyntheticTokenId(idx),
66             };
67             replace.insert(node.clone().into(), vec![replacement]);
68             preorder.skip_subtree();
69             continue;
70         }
71         // In some other situations, we can fix things by just appending some tokens.
72         let end_range = TextRange::empty(node.text_range().end());
73         match_ast! {
74             match node {
75                 ast::FieldExpr(it) => {
76                     if it.name_ref().is_none() {
77                         // incomplete field access: some_expr.|
78                         append.insert(node.clone().into(), vec![
79                             SyntheticToken {
80                                 kind: SyntaxKind::IDENT,
81                                 text: "__ra_fixup".into(),
82                                 range: end_range,
83                                 id: EMPTY_ID,
84                             },
85                         ]);
86                     }
87                 },
88                 ast::ExprStmt(it) => {
89                     if it.semicolon_token().is_none() {
90                         append.insert(node.clone().into(), vec![
91                             SyntheticToken {
92                                 kind: SyntaxKind::SEMICOLON,
93                                 text: ";".into(),
94                                 range: end_range,
95                                 id: EMPTY_ID,
96                             },
97                         ]);
98                     }
99                 },
100                 ast::LetStmt(it) => {
101                     if it.semicolon_token().is_none() {
102                         append.insert(node.clone().into(), vec![
103                             SyntheticToken {
104                                 kind: SyntaxKind::SEMICOLON,
105                                 text: ";".into(),
106                                 range: end_range,
107                                 id: EMPTY_ID,
108                             },
109                         ]);
110                     }
111                 },
112                 ast::IfExpr(it) => {
113                     if it.condition().is_none() {
114                         // insert placeholder token after the if token
115                         let if_token = match it.if_token() {
116                             Some(t) => t,
117                             None => continue,
118                         };
119                         append.insert(if_token.into(), vec![
120                             SyntheticToken {
121                                 kind: SyntaxKind::IDENT,
122                                 text: "__ra_fixup".into(),
123                                 range: end_range,
124                                 id: EMPTY_ID,
125                             },
126                         ]);
127                     }
128                     if it.then_branch().is_none() {
129                         append.insert(node.clone().into(), vec![
130                             SyntheticToken {
131                                 kind: SyntaxKind::L_CURLY,
132                                 text: "{".into(),
133                                 range: end_range,
134                                 id: EMPTY_ID,
135                             },
136                             SyntheticToken {
137                                 kind: SyntaxKind::R_CURLY,
138                                 text: "}".into(),
139                                 range: end_range,
140                                 id: EMPTY_ID,
141                             },
142                         ]);
143                     }
144                 },
145                 ast::WhileExpr(it) => {
146                     if it.condition().is_none() {
147                         // insert placeholder token after the while token
148                         let while_token = match it.while_token() {
149                             Some(t) => t,
150                             None => continue,
151                         };
152                         append.insert(while_token.into(), vec![
153                             SyntheticToken {
154                                 kind: SyntaxKind::IDENT,
155                                 text: "__ra_fixup".into(),
156                                 range: end_range,
157                                 id: EMPTY_ID,
158                             },
159                         ]);
160                     }
161                     if it.loop_body().is_none() {
162                         append.insert(node.clone().into(), vec![
163                             SyntheticToken {
164                                 kind: SyntaxKind::L_CURLY,
165                                 text: "{".into(),
166                                 range: end_range,
167                                 id: EMPTY_ID,
168                             },
169                             SyntheticToken {
170                                 kind: SyntaxKind::R_CURLY,
171                                 text: "}".into(),
172                                 range: end_range,
173                                 id: EMPTY_ID,
174                             },
175                         ]);
176                     }
177                 },
178                 ast::LoopExpr(it) => {
179                     if it.loop_body().is_none() {
180                         append.insert(node.clone().into(), vec![
181                             SyntheticToken {
182                                 kind: SyntaxKind::L_CURLY,
183                                 text: "{".into(),
184                                 range: end_range,
185                                 id: EMPTY_ID,
186                             },
187                             SyntheticToken {
188                                 kind: SyntaxKind::R_CURLY,
189                                 text: "}".into(),
190                                 range: end_range,
191                                 id: EMPTY_ID,
192                             },
193                         ]);
194                     }
195                 },
196                 // FIXME: foo::
197                 ast::MatchExpr(it) => {
198                     if it.expr().is_none() {
199                         let match_token = match it.match_token() {
200                             Some(t) => t,
201                             None => continue
202                         };
203                         append.insert(match_token.into(), vec![
204                             SyntheticToken {
205                                 kind: SyntaxKind::IDENT,
206                                 text: "__ra_fixup".into(),
207                                 range: end_range,
208                                 id: EMPTY_ID
209                             },
210                         ]);
211                     }
212                     if it.match_arm_list().is_none() {
213                         // No match arms
214                         append.insert(node.clone().into(), vec![
215                             SyntheticToken {
216                                 kind: SyntaxKind::L_CURLY,
217                                 text: "{".into(),
218                                 range: end_range,
219                                 id: EMPTY_ID,
220                             },
221                             SyntheticToken {
222                                 kind: SyntaxKind::R_CURLY,
223                                 text: "}".into(),
224                                 range: end_range,
225                                 id: EMPTY_ID,
226                             },
227                         ]);
228                     }
229                 },
230                 ast::ForExpr(it) => {
231                     let for_token = match it.for_token() {
232                         Some(token) => token,
233                         None => continue
234                     };
235 
236                     let [pat, in_token, iter] = [
237                         (SyntaxKind::UNDERSCORE, "_"),
238                         (SyntaxKind::IN_KW, "in"),
239                         (SyntaxKind::IDENT, "__ra_fixup")
240                     ].map(|(kind, text)| SyntheticToken { kind, text: text.into(), range: end_range, id: EMPTY_ID});
241 
242                     if it.pat().is_none() && it.in_token().is_none() && it.iterable().is_none() {
243                         append.insert(for_token.into(), vec![pat, in_token, iter]);
244                     // does something funky -- see test case for_no_pat
245                     } else if it.pat().is_none() {
246                         append.insert(for_token.into(), vec![pat]);
247                     }
248 
249                     if it.loop_body().is_none() {
250                         append.insert(node.clone().into(), vec![
251                             SyntheticToken {
252                                 kind: SyntaxKind::L_CURLY,
253                                 text: "{".into(),
254                                 range: end_range,
255                                 id: EMPTY_ID,
256                             },
257                             SyntheticToken {
258                                 kind: SyntaxKind::R_CURLY,
259                                 text: "}".into(),
260                                 range: end_range,
261                                 id: EMPTY_ID,
262                             },
263                         ]);
264                     }
265                 },
266                 _ => (),
267             }
268         }
269     }
270     SyntaxFixups {
271         append,
272         replace,
273         token_map,
274         next_id,
275         undo_info: SyntaxFixupUndoInfo { original },
276     }
277 }
278 
has_error(node: &SyntaxNode) -> bool279 fn has_error(node: &SyntaxNode) -> bool {
280     node.children().any(|c| c.kind() == SyntaxKind::ERROR)
281 }
282 
can_handle_error(node: &SyntaxNode) -> bool283 fn can_handle_error(node: &SyntaxNode) -> bool {
284     ast::Expr::can_cast(node.kind())
285 }
286 
has_error_to_handle(node: &SyntaxNode) -> bool287 fn has_error_to_handle(node: &SyntaxNode) -> bool {
288     has_error(node) || node.children().any(|c| !can_handle_error(&c) && has_error_to_handle(&c))
289 }
290 
reverse_fixups( tt: &mut Subtree, token_map: &TokenMap, undo_info: &SyntaxFixupUndoInfo, )291 pub(crate) fn reverse_fixups(
292     tt: &mut Subtree,
293     token_map: &TokenMap,
294     undo_info: &SyntaxFixupUndoInfo,
295 ) {
296     let tts = std::mem::take(&mut tt.token_trees);
297     tt.token_trees = tts
298         .into_iter()
299         .filter(|tt| match tt {
300             tt::TokenTree::Leaf(leaf) => {
301                 token_map.synthetic_token_id(*leaf.span()) != Some(EMPTY_ID)
302             }
303             tt::TokenTree::Subtree(st) => {
304                 token_map.synthetic_token_id(st.delimiter.open) != Some(EMPTY_ID)
305             }
306         })
307         .flat_map(|tt| match tt {
308             tt::TokenTree::Subtree(mut tt) => {
309                 reverse_fixups(&mut tt, token_map, undo_info);
310                 SmallVec::from_const([tt.into()])
311             }
312             tt::TokenTree::Leaf(leaf) => {
313                 if let Some(id) = token_map.synthetic_token_id(*leaf.span()) {
314                     let original = undo_info.original[id.0 as usize].clone();
315                     if original.delimiter.kind == tt::DelimiterKind::Invisible {
316                         original.token_trees.into()
317                     } else {
318                         SmallVec::from_const([original.into()])
319                     }
320                 } else {
321                     SmallVec::from_const([leaf.into()])
322                 }
323             }
324         })
325         .collect();
326 }
327 
#[cfg(test)]
mod tests {
    use expect_test::{expect, Expect};

    use crate::tt;

    use super::reverse_fixups;

    // The three helpers below check only *partial* structural equivalence of
    // `TokenTree`s (see the last assertion in `check()`): token ids and
    // `Punct` spacing are deliberately ignored.
    fn check_leaf_eq(a: &tt::Leaf, b: &tt::Leaf) -> bool {
        match (a, b) {
            (tt::Leaf::Ident(lhs), tt::Leaf::Ident(rhs)) => lhs.text == rhs.text,
            (tt::Leaf::Literal(lhs), tt::Leaf::Literal(rhs)) => lhs.text == rhs.text,
            (tt::Leaf::Punct(lhs), tt::Leaf::Punct(rhs)) => lhs.char == rhs.char,
            _ => false,
        }
    }

    fn check_subtree_eq(a: &tt::Subtree, b: &tt::Subtree) -> bool {
        if a.delimiter.kind != b.delimiter.kind || a.token_trees.len() != b.token_trees.len() {
            return false;
        }
        a.token_trees.iter().zip(b.token_trees.iter()).all(|(lhs, rhs)| check_tt_eq(lhs, rhs))
    }

    fn check_tt_eq(a: &tt::TokenTree, b: &tt::TokenTree) -> bool {
        match (a, b) {
            (tt::TokenTree::Subtree(lhs), tt::TokenTree::Subtree(rhs)) => {
                check_subtree_eq(lhs, rhs)
            }
            (tt::TokenTree::Leaf(lhs), tt::TokenTree::Leaf(rhs)) => check_leaf_eq(lhs, rhs),
            _ => false,
        }
    }

    /// Applies the fixups to `ra_fixture` and checks that
    /// 1. the fixed-up token tree prints as `expect`,
    /// 2. the fixed-up tree parses without syntax errors, and
    /// 3. reversing the fixups yields the original input (modulo token ids and
    ///    `Punct` spacing).
    #[track_caller]
    fn check(ra_fixture: &str, mut expect: Expect) {
        let parsed = syntax::SourceFile::parse(ra_fixture);
        let root = parsed.syntax_node();
        let fixups = super::fixup_syntax(&root);
        let (mut tt, tmap, _) = mbe::syntax_node_to_token_tree_with_modifications(
            &root,
            fixups.token_map,
            fixups.next_id,
            fixups.replace,
            fixups.append,
        );

        let actual = format!("{tt}\n");
        expect.indent(false);
        expect.assert_eq(&actual);

        // The fixed-up tree must be syntactically valid.
        let (parse, _) = mbe::token_tree_to_syntax_node(&tt, mbe::TopEntryPoint::MacroItems);
        assert!(
            parse.errors().is_empty(),
            "parse has syntax errors. parse tree:\n{:#?}",
            parse.syntax_node()
        );

        reverse_fixups(&mut tt, &tmap, &fixups.undo_info);

        // Fixed up and then reversed, the tree must be equivalent to the
        // original input modulo token ids and `Punct` spacing.
        let (original_as_tt, _) = mbe::syntax_node_to_token_tree(&root);
        assert!(
            check_subtree_eq(&tt, &original_as_tt),
            "different token tree: {tt:?},\n{original_as_tt:?}"
        );
    }

    #[test]
    fn just_for_token() {
        check(
            r#"
fn foo() {
    for
}
"#,
            expect![[r#"
fn foo () {for _ in __ra_fixup {}}
"#]],
        );
    }

    #[test]
    fn for_no_iter_pattern() {
        check(
            r#"
fn foo() {
    for {}
}
"#,
            expect![[r#"
fn foo () {for _ in __ra_fixup {}}
"#]],
        );
    }

    #[test]
    fn for_no_body() {
        check(
            r#"
fn foo() {
    for bar in qux
}
"#,
            expect![[r#"
fn foo () {for bar in qux {}}
"#]],
        );
    }

    // FIXME: https://github.com/rust-lang/rust-analyzer/pull/12937#discussion_r937633695
    #[test]
    fn for_no_pat() {
        check(
            r#"
fn foo() {
    for in qux {

    }
}
"#,
            expect![[r#"
fn foo () {__ra_fixup}
"#]],
        );
    }

    #[test]
    fn match_no_expr_no_arms() {
        check(
            r#"
fn foo() {
    match
}
"#,
            expect![[r#"
fn foo () {match __ra_fixup {}}
"#]],
        );
    }

    #[test]
    fn match_expr_no_arms() {
        check(
            r#"
fn foo() {
    match x {

    }
}
"#,
            expect![[r#"
fn foo () {match x {}}
"#]],
        );
    }

    #[test]
    fn match_no_expr() {
        check(
            r#"
fn foo() {
    match {
        _ => {}
    }
}
"#,
            expect![[r#"
fn foo () {match __ra_fixup {}}
"#]],
        );
    }

    #[test]
    fn incomplete_field_expr_1() {
        check(
            r#"
fn foo() {
    a.
}
"#,
            expect![[r#"
fn foo () {a . __ra_fixup}
"#]],
        );
    }

    #[test]
    fn incomplete_field_expr_2() {
        check(
            r#"
fn foo() {
    a.;
}
"#,
            expect![[r#"
fn foo () {a . __ra_fixup ;}
"#]],
        );
    }

    #[test]
    fn incomplete_field_expr_3() {
        check(
            r#"
fn foo() {
    a.;
    bar();
}
"#,
            expect![[r#"
fn foo () {a . __ra_fixup ; bar () ;}
"#]],
        );
    }

    #[test]
    fn incomplete_let() {
        check(
            r#"
fn foo() {
    let x = a
}
"#,
            expect![[r#"
fn foo () {let x = a ;}
"#]],
        );
    }

    #[test]
    fn incomplete_field_expr_in_let() {
        check(
            r#"
fn foo() {
    let x = a.
}
"#,
            expect![[r#"
fn foo () {let x = a . __ra_fixup ;}
"#]],
        );
    }

    #[test]
    fn field_expr_before_call() {
        // another case that easily happens while typing
        check(
            r#"
fn foo() {
    a.b
    bar();
}
"#,
            expect![[r#"
fn foo () {a . b ; bar () ;}
"#]],
        );
    }

    #[test]
    fn extraneous_comma() {
        check(
            r#"
fn foo() {
    bar(,);
}
"#,
            expect![[r#"
fn foo () {__ra_fixup ;}
"#]],
        );
    }

    #[test]
    fn fixup_if_1() {
        check(
            r#"
fn foo() {
    if a
}
"#,
            expect![[r#"
fn foo () {if a {}}
"#]],
        );
    }

    #[test]
    fn fixup_if_2() {
        check(
            r#"
fn foo() {
    if
}
"#,
            expect![[r#"
fn foo () {if __ra_fixup {}}
"#]],
        );
    }

    #[test]
    fn fixup_if_3() {
        check(
            r#"
fn foo() {
    if {}
}
"#,
            expect![[r#"
fn foo () {if __ra_fixup {} {}}
"#]],
        );
    }

    #[test]
    fn fixup_while_1() {
        check(
            r#"
fn foo() {
    while
}
"#,
            expect![[r#"
fn foo () {while __ra_fixup {}}
"#]],
        );
    }

    #[test]
    fn fixup_while_2() {
        check(
            r#"
fn foo() {
    while foo
}
"#,
            expect![[r#"
fn foo () {while foo {}}
"#]],
        );
    }

    #[test]
    fn fixup_while_3() {
        check(
            r#"
fn foo() {
    while {}
}
"#,
            expect![[r#"
fn foo () {while __ra_fixup {}}
"#]],
        );
    }

    #[test]
    fn fixup_loop() {
        check(
            r#"
fn foo() {
    loop
}
"#,
            expect![[r#"
fn foo () {loop {}}
"#]],
        );
    }
}
700