• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 use ide_db::defs::{Definition, NameRefClass};
2 use syntax::{
3     ast::{self, HasName, Name},
4     ted, AstNode, SyntaxNode,
5 };
6 
7 use crate::{
8     assist_context::{AssistContext, Assists},
9     AssistId, AssistKind,
10 };
11 
12 // Assist: convert_match_to_let_else
13 //
14 // Converts let statement with match initializer to let-else statement.
15 //
16 // ```
17 // # //- minicore: option
18 // fn foo(opt: Option<()>) {
19 //     let val$0 = match opt {
20 //         Some(it) => it,
21 //         None => return,
22 //     };
23 // }
24 // ```
25 // ->
26 // ```
27 // fn foo(opt: Option<()>) {
28 //     let Some(val) = opt else { return };
29 // }
30 // ```
convert_match_to_let_else(acc: &mut Assists, ctx: &AssistContext<'_>) -> Option<()>31 pub(crate) fn convert_match_to_let_else(acc: &mut Assists, ctx: &AssistContext<'_>) -> Option<()> {
32     let let_stmt: ast::LetStmt = ctx.find_node_at_offset()?;
33     let pat = let_stmt.pat()?;
34     if ctx.offset() > pat.syntax().text_range().end() {
35         return None;
36     }
37 
38     let Some(ast::Expr::MatchExpr(initializer)) = let_stmt.initializer() else { return None };
39     let initializer_expr = initializer.expr()?;
40 
41     let Some((extracting_arm, diverging_arm)) = find_arms(ctx, &initializer) else { return None };
42     if extracting_arm.guard().is_some() {
43         cov_mark::hit!(extracting_arm_has_guard);
44         return None;
45     }
46 
47     let diverging_arm_expr = match diverging_arm.expr()? {
48         ast::Expr::BlockExpr(block) if block.modifier().is_none() && block.label().is_none() => {
49             block.to_string()
50         }
51         other => format!("{{ {other} }}"),
52     };
53     let extracting_arm_pat = extracting_arm.pat()?;
54     let extracted_variable_positions = find_extracted_variable(ctx, &extracting_arm)?;
55 
56     acc.add(
57         AssistId("convert_match_to_let_else", AssistKind::RefactorRewrite),
58         "Convert match to let-else",
59         let_stmt.syntax().text_range(),
60         |builder| {
61             let extracting_arm_pat =
62                 rename_variable(&extracting_arm_pat, &extracted_variable_positions, pat);
63             builder.replace(
64                 let_stmt.syntax().text_range(),
65                 format!("let {extracting_arm_pat} = {initializer_expr} else {diverging_arm_expr};"),
66             )
67         },
68     )
69 }
70 
71 // Given a match expression, find extracting and diverging arms.
find_arms( ctx: &AssistContext<'_>, match_expr: &ast::MatchExpr, ) -> Option<(ast::MatchArm, ast::MatchArm)>72 fn find_arms(
73     ctx: &AssistContext<'_>,
74     match_expr: &ast::MatchExpr,
75 ) -> Option<(ast::MatchArm, ast::MatchArm)> {
76     let arms = match_expr.match_arm_list()?.arms().collect::<Vec<_>>();
77     if arms.len() != 2 {
78         return None;
79     }
80 
81     let mut extracting = None;
82     let mut diverging = None;
83     for arm in arms {
84         if ctx.sema.type_of_expr(&arm.expr()?)?.original().is_never() {
85             diverging = Some(arm);
86         } else {
87             extracting = Some(arm);
88         }
89     }
90 
91     match (extracting, diverging) {
92         (Some(extracting), Some(diverging)) => Some((extracting, diverging)),
93         _ => {
94             cov_mark::hit!(non_diverging_match);
95             None
96         }
97     }
98 }
99 
100 // Given an extracting arm, find the extracted variable.
find_extracted_variable(ctx: &AssistContext<'_>, arm: &ast::MatchArm) -> Option<Vec<Name>>101 fn find_extracted_variable(ctx: &AssistContext<'_>, arm: &ast::MatchArm) -> Option<Vec<Name>> {
102     match arm.expr()? {
103         ast::Expr::PathExpr(path) => {
104             let name_ref = path.syntax().descendants().find_map(ast::NameRef::cast)?;
105             match NameRefClass::classify(&ctx.sema, &name_ref)? {
106                 NameRefClass::Definition(Definition::Local(local)) => {
107                     let source =
108                         local.sources(ctx.db()).into_iter().map(|x| x.into_ident_pat()?.name());
109                     source.collect()
110                 }
111                 _ => None,
112             }
113         }
114         _ => {
115             cov_mark::hit!(extracting_arm_is_not_an_identity_expr);
116             return None;
117         }
118     }
119 }
120 
121 // Rename `extracted` with `binding` in `pat`.
rename_variable(pat: &ast::Pat, extracted: &[Name], binding: ast::Pat) -> SyntaxNode122 fn rename_variable(pat: &ast::Pat, extracted: &[Name], binding: ast::Pat) -> SyntaxNode {
123     let syntax = pat.syntax().clone_for_update();
124     let extracted = extracted
125         .iter()
126         .map(|e| syntax.covering_element(e.syntax().text_range()))
127         .collect::<Vec<_>>();
128     for extracted_syntax in extracted {
129         // If `extracted` variable is a record field, we should rename it to `binding`,
130         // otherwise we just need to replace `extracted` with `binding`.
131 
132         if let Some(record_pat_field) =
133             extracted_syntax.ancestors().find_map(ast::RecordPatField::cast)
134         {
135             if let Some(name_ref) = record_pat_field.field_name() {
136                 ted::replace(
137                     record_pat_field.syntax(),
138                     ast::make::record_pat_field(
139                         ast::make::name_ref(&name_ref.text()),
140                         binding.clone(),
141                     )
142                     .syntax()
143                     .clone_for_update(),
144                 );
145             }
146         } else {
147             ted::replace(extracted_syntax, binding.clone().syntax().clone_for_update());
148         }
149     }
150     syntax
151 }
152 
#[cfg(test)]
mod tests {
    use super::*;

    use crate::tests::{check_assist, check_assist_not_applicable};

    // The `None` arm evaluates to `()` rather than `!`, so there is no diverging arm.
    #[test]
    fn should_not_be_applicable_for_non_diverging_match() {
        cov_mark::check!(non_diverging_match);
        let fixture = r#"
//- minicore: option
fn foo(opt: Option<()>) {
    let val$0 = match opt {
        Some(it) => it,
        None => (),
    };
}
"#;
        check_assist_not_applicable(convert_match_to_let_else, fixture);
    }

    // Every alternative of an or-pattern that binds the extracted variable gets renamed.
    #[test]
    fn or_pattern_multiple_binding() {
        let before = r#"
//- minicore: option
enum Foo {
    A(u32),
    B(u32),
    C(String),
}

fn foo(opt: Option<Foo>) -> Result<u32, ()> {
    let va$0lue = match opt {
        Some(Foo::A(it) | Foo::B(it)) => it,
        _ => return Err(()),
    };
}
    "#;
        let after = r#"
enum Foo {
    A(u32),
    B(u32),
    C(String),
}

fn foo(opt: Option<Foo>) -> Result<u32, ()> {
    let Some(Foo::A(value) | Foo::B(value)) = opt else { return Err(()) };
}
    "#;
        check_assist(convert_match_to_let_else, before, after);
    }

    // The extracting arm must evaluate to the bound variable itself, not a computation on it.
    #[test]
    fn should_not_be_applicable_if_extracting_arm_is_not_an_identity_expr() {
        cov_mark::check_count!(extracting_arm_is_not_an_identity_expr, 2);
        let arithmetic = r#"
//- minicore: option
fn foo(opt: Option<i32>) {
    let val$0 = match opt {
        Some(it) => it + 1,
        None => return,
    };
}
"#;
        let block_with_statements = r#"
//- minicore: option
fn foo(opt: Option<()>) {
    let val$0 = match opt {
        Some(it) => {
            let _ = 1 + 1;
            it
        },
        None => return,
    };
}
"#;
        for fixture in [arithmetic, block_with_statements] {
            check_assist_not_applicable(convert_match_to_let_else, fixture);
        }
    }

    // A guard on the extracting arm has no let-else equivalent.
    #[test]
    fn should_not_be_applicable_if_extracting_arm_has_guard() {
        cov_mark::check!(extracting_arm_has_guard);
        let fixture = r#"
//- minicore: option
fn foo(opt: Option<()>) {
    let val$0 = match opt {
        Some(it) if 2 > 1 => it,
        None => return,
    };
}
"#;
        check_assist_not_applicable(convert_match_to_let_else, fixture);
    }

    // The simplest `Some(it) => it, None => return` shape.
    #[test]
    fn basic_pattern() {
        let before = r#"
//- minicore: option
fn foo(opt: Option<()>) {
    let val$0 = match opt {
        Some(it) => it,
        None => return,
    };
}
    "#;
        let after = r#"
fn foo(opt: Option<()>) {
    let Some(val) = opt else { return };
}
    "#;
        check_assist(convert_match_to_let_else, before, after);
    }

    // `ref mut` on the `let` pattern is carried into the let-else pattern.
    #[test]
    fn keeps_modifiers() {
        let before = r#"
//- minicore: option
fn foo(opt: Option<()>) {
    let ref mut val$0 = match opt {
        Some(it) => it,
        None => return,
    };
}
    "#;
        let after = r#"
fn foo(opt: Option<()>) {
    let Some(ref mut val) = opt else { return };
}
    "#;
        check_assist(convert_match_to_let_else, before, after);
    }

    // The extracted variable may sit several patterns deep.
    #[test]
    fn nested_pattern() {
        let before = r#"
//- minicore: option, result
fn foo(opt: Option<Result<()>>) {
    let val$0 = match opt {
        Some(Ok(it)) => it,
        _ => return,
    };
}
    "#;
        let after = r#"
fn foo(opt: Option<Result<()>>) {
    let Some(Ok(val)) = opt else { return };
}
    "#;
        check_assist(convert_match_to_let_else, before, after);
    }

    // `break`, `continue` and calls to `!`-returning functions all count as diverging.
    #[test]
    fn works_with_any_diverging_block() {
        let with_break = (
            r#"
//- minicore: option
fn foo(opt: Option<()>) {
    loop {
        let val$0 = match opt {
            Some(it) => it,
            None => break,
        };
    }
}
    "#,
            r#"
fn foo(opt: Option<()>) {
    loop {
        let Some(val) = opt else { break };
    }
}
    "#,
        );
        let with_continue = (
            r#"
//- minicore: option
fn foo(opt: Option<()>) {
    loop {
        let val$0 = match opt {
            Some(it) => it,
            None => continue,
        };
    }
}
    "#,
            r#"
fn foo(opt: Option<()>) {
    loop {
        let Some(val) = opt else { continue };
    }
}
    "#,
        );
        let with_never_call = (
            r#"
//- minicore: option
fn panic() -> ! {}

fn foo(opt: Option<()>) {
    loop {
        let val$0 = match opt {
            Some(it) => it,
            None => panic(),
        };
    }
}
    "#,
            r#"
fn panic() -> ! {}

fn foo(opt: Option<()>) {
    loop {
        let Some(val) = opt else { panic() };
    }
}
    "#,
        );
        for (before, after) in [with_break, with_continue, with_never_call] {
            check_assist(convert_match_to_let_else, before, after);
        }
    }

    // A shorthand record field binding becomes an explicit `field: binding`.
    #[test]
    fn struct_pattern() {
        let before = r#"
//- minicore: option
struct Point {
    x: i32,
    y: i32,
}

fn foo(opt: Option<Point>) {
    let val$0 = match opt {
        Some(Point { x: 0, y }) => y,
        _ => return,
    };
}
    "#;
        let after = r#"
struct Point {
    x: i32,
    y: i32,
}

fn foo(opt: Option<Point>) {
    let Some(Point { x: 0, y: val }) = opt else { return };
}
    "#;
        check_assist(convert_match_to_let_else, before, after);
    }

    // An `@` binding of the whole scrutinee is renamed too.
    #[test]
    fn renames_whole_binding() {
        let before = r#"
//- minicore: option
fn foo(opt: Option<i32>) -> Option<i32> {
    let val$0 = match opt {
        it @ Some(42) => it,
        _ => return None,
    };
    val
}
    "#;
        let after = r#"
fn foo(opt: Option<i32>) -> Option<i32> {
    let val @ Some(42) = opt else { return None };
    val
}
    "#;
        check_assist(convert_match_to_let_else, before, after);
    }

    // A destructuring `let` pattern is spliced in place of the extracted variable.
    #[test]
    fn complex_pattern() {
        let before = r#"
//- minicore: option
fn f() {
    let (x, y)$0 = match Some((0, 1)) {
        Some(it) => it,
        None => return,
    };
}
"#;
        let after = r#"
fn f() {
    let Some((x, y)) = Some((0, 1)) else { return };
}
"#;
        check_assist(convert_match_to_let_else, before, after);
    }

    // A plain block on the diverging arm is reused verbatim, comments included.
    #[test]
    fn diverging_block() {
        let before = r#"
//- minicore: option
fn f() {
    let x$0 = match Some(()) {
        Some(it) => it,
        None => {//comment
            println!("nope");
            return
        },
    };
}
"#;
        let after = r#"
fn f() {
    let Some(x) = Some(()) else {//comment
            println!("nope");
            return
        };
}
"#;
        check_assist(convert_match_to_let_else, before, after);
    }
}
497