1 use ide_db::defs::{Definition, NameRefClass};
2 use syntax::{
3 ast::{self, HasName, Name},
4 ted, AstNode, SyntaxNode,
5 };
6
7 use crate::{
8 assist_context::{AssistContext, Assists},
9 AssistId, AssistKind,
10 };
11
12 // Assist: convert_match_to_let_else
13 //
// Converts a let statement with a match initializer to a let-else statement.
15 //
16 // ```
17 // # //- minicore: option
18 // fn foo(opt: Option<()>) {
19 // let val$0 = match opt {
20 // Some(it) => it,
21 // None => return,
22 // };
23 // }
24 // ```
25 // ->
26 // ```
27 // fn foo(opt: Option<()>) {
28 // let Some(val) = opt else { return };
29 // }
30 // ```
convert_match_to_let_else(acc: &mut Assists, ctx: &AssistContext<'_>) -> Option<()>31 pub(crate) fn convert_match_to_let_else(acc: &mut Assists, ctx: &AssistContext<'_>) -> Option<()> {
32 let let_stmt: ast::LetStmt = ctx.find_node_at_offset()?;
33 let pat = let_stmt.pat()?;
34 if ctx.offset() > pat.syntax().text_range().end() {
35 return None;
36 }
37
38 let Some(ast::Expr::MatchExpr(initializer)) = let_stmt.initializer() else { return None };
39 let initializer_expr = initializer.expr()?;
40
41 let Some((extracting_arm, diverging_arm)) = find_arms(ctx, &initializer) else { return None };
42 if extracting_arm.guard().is_some() {
43 cov_mark::hit!(extracting_arm_has_guard);
44 return None;
45 }
46
47 let diverging_arm_expr = match diverging_arm.expr()? {
48 ast::Expr::BlockExpr(block) if block.modifier().is_none() && block.label().is_none() => {
49 block.to_string()
50 }
51 other => format!("{{ {other} }}"),
52 };
53 let extracting_arm_pat = extracting_arm.pat()?;
54 let extracted_variable_positions = find_extracted_variable(ctx, &extracting_arm)?;
55
56 acc.add(
57 AssistId("convert_match_to_let_else", AssistKind::RefactorRewrite),
58 "Convert match to let-else",
59 let_stmt.syntax().text_range(),
60 |builder| {
61 let extracting_arm_pat =
62 rename_variable(&extracting_arm_pat, &extracted_variable_positions, pat);
63 builder.replace(
64 let_stmt.syntax().text_range(),
65 format!("let {extracting_arm_pat} = {initializer_expr} else {diverging_arm_expr};"),
66 )
67 },
68 )
69 }
70
71 // Given a match expression, find extracting and diverging arms.
find_arms( ctx: &AssistContext<'_>, match_expr: &ast::MatchExpr, ) -> Option<(ast::MatchArm, ast::MatchArm)>72 fn find_arms(
73 ctx: &AssistContext<'_>,
74 match_expr: &ast::MatchExpr,
75 ) -> Option<(ast::MatchArm, ast::MatchArm)> {
76 let arms = match_expr.match_arm_list()?.arms().collect::<Vec<_>>();
77 if arms.len() != 2 {
78 return None;
79 }
80
81 let mut extracting = None;
82 let mut diverging = None;
83 for arm in arms {
84 if ctx.sema.type_of_expr(&arm.expr()?)?.original().is_never() {
85 diverging = Some(arm);
86 } else {
87 extracting = Some(arm);
88 }
89 }
90
91 match (extracting, diverging) {
92 (Some(extracting), Some(diverging)) => Some((extracting, diverging)),
93 _ => {
94 cov_mark::hit!(non_diverging_match);
95 None
96 }
97 }
98 }
99
100 // Given an extracting arm, find the extracted variable.
find_extracted_variable(ctx: &AssistContext<'_>, arm: &ast::MatchArm) -> Option<Vec<Name>>101 fn find_extracted_variable(ctx: &AssistContext<'_>, arm: &ast::MatchArm) -> Option<Vec<Name>> {
102 match arm.expr()? {
103 ast::Expr::PathExpr(path) => {
104 let name_ref = path.syntax().descendants().find_map(ast::NameRef::cast)?;
105 match NameRefClass::classify(&ctx.sema, &name_ref)? {
106 NameRefClass::Definition(Definition::Local(local)) => {
107 let source =
108 local.sources(ctx.db()).into_iter().map(|x| x.into_ident_pat()?.name());
109 source.collect()
110 }
111 _ => None,
112 }
113 }
114 _ => {
115 cov_mark::hit!(extracting_arm_is_not_an_identity_expr);
116 return None;
117 }
118 }
119 }
120
121 // Rename `extracted` with `binding` in `pat`.
rename_variable(pat: &ast::Pat, extracted: &[Name], binding: ast::Pat) -> SyntaxNode122 fn rename_variable(pat: &ast::Pat, extracted: &[Name], binding: ast::Pat) -> SyntaxNode {
123 let syntax = pat.syntax().clone_for_update();
124 let extracted = extracted
125 .iter()
126 .map(|e| syntax.covering_element(e.syntax().text_range()))
127 .collect::<Vec<_>>();
128 for extracted_syntax in extracted {
129 // If `extracted` variable is a record field, we should rename it to `binding`,
130 // otherwise we just need to replace `extracted` with `binding`.
131
132 if let Some(record_pat_field) =
133 extracted_syntax.ancestors().find_map(ast::RecordPatField::cast)
134 {
135 if let Some(name_ref) = record_pat_field.field_name() {
136 ted::replace(
137 record_pat_field.syntax(),
138 ast::make::record_pat_field(
139 ast::make::name_ref(&name_ref.text()),
140 binding.clone(),
141 )
142 .syntax()
143 .clone_for_update(),
144 );
145 }
146 } else {
147 ted::replace(extracted_syntax, binding.clone().syntax().clone_for_update());
148 }
149 }
150 syntax
151 }
152
#[cfg(test)]
mod tests {
    // Fixture-based tests: `$0` marks the cursor, `//- minicore:` pulls in core stubs, and
    // `check_assist` compares the rewritten text against the expected fixture.
    use crate::tests::{check_assist, check_assist_not_applicable};

    use super::*;

    // The `None` arm evaluates to `()`, not `!`, so there is no diverging arm.
    #[test]
    fn should_not_be_applicable_for_non_diverging_match() {
        cov_mark::check!(non_diverging_match);
        check_assist_not_applicable(
            convert_match_to_let_else,
            r#"
//- minicore: option
fn foo(opt: Option<()>) {
let val$0 = match opt {
Some(it) => it,
None => (),
};
}
"#,
        );
    }

    // Every binding site of an or-pattern is renamed to the `let` pattern.
    #[test]
    fn or_pattern_multiple_binding() {
        check_assist(
            convert_match_to_let_else,
            r#"
//- minicore: option
enum Foo {
A(u32),
B(u32),
C(String),
}

fn foo(opt: Option<Foo>) -> Result<u32, ()> {
let va$0lue = match opt {
Some(Foo::A(it) | Foo::B(it)) => it,
_ => return Err(()),
};
}
"#,
            r#"
enum Foo {
A(u32),
B(u32),
C(String),
}

fn foo(opt: Option<Foo>) -> Result<u32, ()> {
let Some(Foo::A(value) | Foo::B(value)) = opt else { return Err(()) };
}
"#,
        );
    }

    // The extracting arm must just name the bound variable: neither an arithmetic expression
    // nor a block computing it is accepted.
    #[test]
    fn should_not_be_applicable_if_extracting_arm_is_not_an_identity_expr() {
        cov_mark::check_count!(extracting_arm_is_not_an_identity_expr, 2);
        check_assist_not_applicable(
            convert_match_to_let_else,
            r#"
//- minicore: option
fn foo(opt: Option<i32>) {
let val$0 = match opt {
Some(it) => it + 1,
None => return,
};
}
"#,
        );

        check_assist_not_applicable(
            convert_match_to_let_else,
            r#"
//- minicore: option
fn foo(opt: Option<()>) {
let val$0 = match opt {
Some(it) => {
let _ = 1 + 1;
it
},
None => return,
};
}
"#,
        );
    }

    // A guard on the extracting arm cannot be expressed in let-else.
    #[test]
    fn should_not_be_applicable_if_extracting_arm_has_guard() {
        cov_mark::check!(extracting_arm_has_guard);
        check_assist_not_applicable(
            convert_match_to_let_else,
            r#"
//- minicore: option
fn foo(opt: Option<()>) {
let val$0 = match opt {
Some(it) if 2 > 1 => it,
None => return,
};
}
"#,
        );
    }

    // Simplest case: a diverging expression arm is wrapped in braces.
    #[test]
    fn basic_pattern() {
        check_assist(
            convert_match_to_let_else,
            r#"
//- minicore: option
fn foo(opt: Option<()>) {
let val$0 = match opt {
Some(it) => it,
None => return,
};
}
"#,
            r#"
fn foo(opt: Option<()>) {
let Some(val) = opt else { return };
}
"#,
        );
    }

    // `ref mut` on the `let` pattern moves into the generated pattern.
    #[test]
    fn keeps_modifiers() {
        check_assist(
            convert_match_to_let_else,
            r#"
//- minicore: option
fn foo(opt: Option<()>) {
let ref mut val$0 = match opt {
Some(it) => it,
None => return,
};
}
"#,
            r#"
fn foo(opt: Option<()>) {
let Some(ref mut val) = opt else { return };
}
"#,
        );
    }

    // Nested patterns keep their structure; a wildcard diverging arm is fine.
    #[test]
    fn nested_pattern() {
        check_assist(
            convert_match_to_let_else,
            r#"
//- minicore: option, result
fn foo(opt: Option<Result<()>>) {
let val$0 = match opt {
Some(Ok(it)) => it,
_ => return,
};
}
"#,
            r#"
fn foo(opt: Option<Result<()>>) {
let Some(Ok(val)) = opt else { return };
}
"#,
        );
    }

    // `break`, `continue` and calls to a `-> !` function all count as diverging.
    #[test]
    fn works_with_any_diverging_block() {
        check_assist(
            convert_match_to_let_else,
            r#"
//- minicore: option
fn foo(opt: Option<()>) {
loop {
let val$0 = match opt {
Some(it) => it,
None => break,
};
}
}
"#,
            r#"
fn foo(opt: Option<()>) {
loop {
let Some(val) = opt else { break };
}
}
"#,
        );

        check_assist(
            convert_match_to_let_else,
            r#"
//- minicore: option
fn foo(opt: Option<()>) {
loop {
let val$0 = match opt {
Some(it) => it,
None => continue,
};
}
}
"#,
            r#"
fn foo(opt: Option<()>) {
loop {
let Some(val) = opt else { continue };
}
}
"#,
        );

        check_assist(
            convert_match_to_let_else,
            r#"
//- minicore: option
fn panic() -> ! {}

fn foo(opt: Option<()>) {
loop {
let val$0 = match opt {
Some(it) => it,
None => panic(),
};
}
}
"#,
            r#"
fn panic() -> ! {}

fn foo(opt: Option<()>) {
loop {
let Some(val) = opt else { panic() };
}
}
"#,
        );
    }

    // A shorthand record field binding is expanded to `field: binding`.
    #[test]
    fn struct_pattern() {
        check_assist(
            convert_match_to_let_else,
            r#"
//- minicore: option
struct Point {
x: i32,
y: i32,
}

fn foo(opt: Option<Point>) {
let val$0 = match opt {
Some(Point { x: 0, y }) => y,
_ => return,
};
}
"#,
            r#"
struct Point {
x: i32,
y: i32,
}

fn foo(opt: Option<Point>) {
let Some(Point { x: 0, y: val }) = opt else { return };
}
"#,
        );
    }

    // Only the identifier of an `@` binding is renamed; its sub-pattern is kept.
    #[test]
    fn renames_whole_binding() {
        check_assist(
            convert_match_to_let_else,
            r#"
//- minicore: option
fn foo(opt: Option<i32>) -> Option<i32> {
let val$0 = match opt {
it @ Some(42) => it,
_ => return None,
};
val
}
"#,
            r#"
fn foo(opt: Option<i32>) -> Option<i32> {
let val @ Some(42) = opt else { return None };
val
}
"#,
        );
    }

    // A non-identifier `let` pattern (here a tuple) is substituted as a whole.
    #[test]
    fn complex_pattern() {
        check_assist(
            convert_match_to_let_else,
            r#"
//- minicore: option
fn f() {
let (x, y)$0 = match Some((0, 1)) {
Some(it) => it,
None => return,
};
}
"#,
            r#"
fn f() {
let Some((x, y)) = Some((0, 1)) else { return };
}
"#,
        );
    }

    // A plain block diverging arm is reused verbatim, comments included.
    #[test]
    fn diverging_block() {
        check_assist(
            convert_match_to_let_else,
            r#"
//- minicore: option
fn f() {
let x$0 = match Some(()) {
Some(it) => it,
None => {//comment
println!("nope");
return
},
};
}
"#,
            r#"
fn f() {
let Some(x) = Some(()) else {//comment
println!("nope");
return
};
}
"#,
        );
    }
}
497