• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 //! Structural editing for ast.
2 
3 use std::iter::{empty, successors};
4 
5 use parser::{SyntaxKind, T};
6 use rowan::SyntaxElement;
7 
8 use crate::{
9     algo::{self, neighbor},
10     ast::{self, edit::IndentLevel, make, HasGenericParams},
11     ted::{self, Position},
12     AstNode, AstToken, Direction,
13     SyntaxKind::{ATTR, COMMENT, WHITESPACE},
14     SyntaxNode, SyntaxToken,
15 };
16 
17 use super::HasName;
18 
19 pub trait GenericParamsOwnerEdit: ast::HasGenericParams {
get_or_create_generic_param_list(&self) -> ast::GenericParamList20     fn get_or_create_generic_param_list(&self) -> ast::GenericParamList;
get_or_create_where_clause(&self) -> ast::WhereClause21     fn get_or_create_where_clause(&self) -> ast::WhereClause;
22 }
23 
24 impl GenericParamsOwnerEdit for ast::Fn {
get_or_create_generic_param_list(&self) -> ast::GenericParamList25     fn get_or_create_generic_param_list(&self) -> ast::GenericParamList {
26         match self.generic_param_list() {
27             Some(it) => it,
28             None => {
29                 let position = if let Some(name) = self.name() {
30                     Position::after(name.syntax)
31                 } else if let Some(fn_token) = self.fn_token() {
32                     Position::after(fn_token)
33                 } else if let Some(param_list) = self.param_list() {
34                     Position::before(param_list.syntax)
35                 } else {
36                     Position::last_child_of(self.syntax())
37                 };
38                 create_generic_param_list(position)
39             }
40         }
41     }
42 
get_or_create_where_clause(&self) -> ast::WhereClause43     fn get_or_create_where_clause(&self) -> ast::WhereClause {
44         if self.where_clause().is_none() {
45             let position = if let Some(ty) = self.ret_type() {
46                 Position::after(ty.syntax())
47             } else if let Some(param_list) = self.param_list() {
48                 Position::after(param_list.syntax())
49             } else {
50                 Position::last_child_of(self.syntax())
51             };
52             create_where_clause(position);
53         }
54         self.where_clause().unwrap()
55     }
56 }
57 
58 impl GenericParamsOwnerEdit for ast::Impl {
get_or_create_generic_param_list(&self) -> ast::GenericParamList59     fn get_or_create_generic_param_list(&self) -> ast::GenericParamList {
60         match self.generic_param_list() {
61             Some(it) => it,
62             None => {
63                 let position = match self.impl_token() {
64                     Some(imp_token) => Position::after(imp_token),
65                     None => Position::last_child_of(self.syntax()),
66                 };
67                 create_generic_param_list(position)
68             }
69         }
70     }
71 
get_or_create_where_clause(&self) -> ast::WhereClause72     fn get_or_create_where_clause(&self) -> ast::WhereClause {
73         if self.where_clause().is_none() {
74             let position = match self.assoc_item_list() {
75                 Some(items) => Position::before(items.syntax()),
76                 None => Position::last_child_of(self.syntax()),
77             };
78             create_where_clause(position);
79         }
80         self.where_clause().unwrap()
81     }
82 }
83 
84 impl GenericParamsOwnerEdit for ast::Trait {
get_or_create_generic_param_list(&self) -> ast::GenericParamList85     fn get_or_create_generic_param_list(&self) -> ast::GenericParamList {
86         match self.generic_param_list() {
87             Some(it) => it,
88             None => {
89                 let position = if let Some(name) = self.name() {
90                     Position::after(name.syntax)
91                 } else if let Some(trait_token) = self.trait_token() {
92                     Position::after(trait_token)
93                 } else {
94                     Position::last_child_of(self.syntax())
95                 };
96                 create_generic_param_list(position)
97             }
98         }
99     }
100 
get_or_create_where_clause(&self) -> ast::WhereClause101     fn get_or_create_where_clause(&self) -> ast::WhereClause {
102         if self.where_clause().is_none() {
103             let position = match self.assoc_item_list() {
104                 Some(items) => Position::before(items.syntax()),
105                 None => Position::last_child_of(self.syntax()),
106             };
107             create_where_clause(position);
108         }
109         self.where_clause().unwrap()
110     }
111 }
112 
113 impl GenericParamsOwnerEdit for ast::Struct {
get_or_create_generic_param_list(&self) -> ast::GenericParamList114     fn get_or_create_generic_param_list(&self) -> ast::GenericParamList {
115         match self.generic_param_list() {
116             Some(it) => it,
117             None => {
118                 let position = if let Some(name) = self.name() {
119                     Position::after(name.syntax)
120                 } else if let Some(struct_token) = self.struct_token() {
121                     Position::after(struct_token)
122                 } else {
123                     Position::last_child_of(self.syntax())
124                 };
125                 create_generic_param_list(position)
126             }
127         }
128     }
129 
get_or_create_where_clause(&self) -> ast::WhereClause130     fn get_or_create_where_clause(&self) -> ast::WhereClause {
131         if self.where_clause().is_none() {
132             let tfl = self.field_list().and_then(|fl| match fl {
133                 ast::FieldList::RecordFieldList(_) => None,
134                 ast::FieldList::TupleFieldList(it) => Some(it),
135             });
136             let position = if let Some(tfl) = tfl {
137                 Position::after(tfl.syntax())
138             } else if let Some(gpl) = self.generic_param_list() {
139                 Position::after(gpl.syntax())
140             } else if let Some(name) = self.name() {
141                 Position::after(name.syntax())
142             } else {
143                 Position::last_child_of(self.syntax())
144             };
145             create_where_clause(position);
146         }
147         self.where_clause().unwrap()
148     }
149 }
150 
151 impl GenericParamsOwnerEdit for ast::Enum {
get_or_create_generic_param_list(&self) -> ast::GenericParamList152     fn get_or_create_generic_param_list(&self) -> ast::GenericParamList {
153         match self.generic_param_list() {
154             Some(it) => it,
155             None => {
156                 let position = if let Some(name) = self.name() {
157                     Position::after(name.syntax)
158                 } else if let Some(enum_token) = self.enum_token() {
159                     Position::after(enum_token)
160                 } else {
161                     Position::last_child_of(self.syntax())
162                 };
163                 create_generic_param_list(position)
164             }
165         }
166     }
167 
get_or_create_where_clause(&self) -> ast::WhereClause168     fn get_or_create_where_clause(&self) -> ast::WhereClause {
169         if self.where_clause().is_none() {
170             let position = if let Some(gpl) = self.generic_param_list() {
171                 Position::after(gpl.syntax())
172             } else if let Some(name) = self.name() {
173                 Position::after(name.syntax())
174             } else {
175                 Position::last_child_of(self.syntax())
176             };
177             create_where_clause(position);
178         }
179         self.where_clause().unwrap()
180     }
181 }
182 
create_where_clause(position: Position)183 fn create_where_clause(position: Position) {
184     let where_clause = make::where_clause(empty()).clone_for_update();
185     ted::insert(position, where_clause.syntax());
186 }
187 
create_generic_param_list(position: Position) -> ast::GenericParamList188 fn create_generic_param_list(position: Position) -> ast::GenericParamList {
189     let gpl = make::generic_param_list(empty()).clone_for_update();
190     ted::insert_raw(position, gpl.syntax());
191     gpl
192 }
193 
194 pub trait AttrsOwnerEdit: ast::HasAttrs {
remove_attrs_and_docs(&self)195     fn remove_attrs_and_docs(&self) {
196         remove_attrs_and_docs(self.syntax());
197 
198         fn remove_attrs_and_docs(node: &SyntaxNode) {
199             let mut remove_next_ws = false;
200             for child in node.children_with_tokens() {
201                 match child.kind() {
202                     ATTR | COMMENT => {
203                         remove_next_ws = true;
204                         child.detach();
205                         continue;
206                     }
207                     WHITESPACE if remove_next_ws => {
208                         child.detach();
209                     }
210                     _ => (),
211                 }
212                 remove_next_ws = false;
213             }
214         }
215     }
216 }
217 
218 impl<T: ast::HasAttrs> AttrsOwnerEdit for T {}
219 
220 impl ast::GenericParamList {
add_generic_param(&self, generic_param: ast::GenericParam)221     pub fn add_generic_param(&self, generic_param: ast::GenericParam) {
222         match self.generic_params().last() {
223             Some(last_param) => {
224                 let position = Position::after(last_param.syntax());
225                 let elements = vec![
226                     make::token(T![,]).into(),
227                     make::tokens::single_space().into(),
228                     generic_param.syntax().clone().into(),
229                 ];
230                 ted::insert_all(position, elements);
231             }
232             None => {
233                 let after_l_angle = Position::after(self.l_angle_token().unwrap());
234                 ted::insert(after_l_angle, generic_param.syntax());
235             }
236         }
237     }
238 
239     /// Removes the existing generic param
remove_generic_param(&self, generic_param: ast::GenericParam)240     pub fn remove_generic_param(&self, generic_param: ast::GenericParam) {
241         if let Some(previous) = generic_param.syntax().prev_sibling() {
242             if let Some(next_token) = previous.next_sibling_or_token() {
243                 ted::remove_all(next_token..=generic_param.syntax().clone().into());
244             }
245         } else if let Some(next) = generic_param.syntax().next_sibling() {
246             if let Some(next_token) = next.prev_sibling_or_token() {
247                 ted::remove_all(generic_param.syntax().clone().into()..=next_token);
248             }
249         } else {
250             ted::remove(generic_param.syntax());
251         }
252     }
253 
254     /// Constructs a matching [`ast::GenericArgList`]
to_generic_args(&self) -> ast::GenericArgList255     pub fn to_generic_args(&self) -> ast::GenericArgList {
256         let args = self.generic_params().filter_map(|param| match param {
257             ast::GenericParam::LifetimeParam(it) => {
258                 Some(ast::GenericArg::LifetimeArg(make::lifetime_arg(it.lifetime()?)))
259             }
260             ast::GenericParam::TypeParam(it) => {
261                 Some(ast::GenericArg::TypeArg(make::type_arg(make::ext::ty_name(it.name()?))))
262             }
263             ast::GenericParam::ConstParam(it) => {
264                 // Name-only const params get parsed as `TypeArg`s
265                 Some(ast::GenericArg::TypeArg(make::type_arg(make::ext::ty_name(it.name()?))))
266             }
267         });
268 
269         make::generic_arg_list(args)
270     }
271 }
272 
273 impl ast::WhereClause {
add_predicate(&self, predicate: ast::WherePred)274     pub fn add_predicate(&self, predicate: ast::WherePred) {
275         if let Some(pred) = self.predicates().last() {
276             if !pred.syntax().siblings_with_tokens(Direction::Next).any(|it| it.kind() == T![,]) {
277                 ted::append_child_raw(self.syntax(), make::token(T![,]));
278             }
279         }
280         ted::append_child(self.syntax(), predicate.syntax());
281     }
282 }
283 
284 impl ast::TypeParam {
remove_default(&self)285     pub fn remove_default(&self) {
286         if let Some((eq, last)) = self
287             .syntax()
288             .children_with_tokens()
289             .find(|it| it.kind() == T![=])
290             .zip(self.syntax().last_child_or_token())
291         {
292             ted::remove_all(eq..=last);
293 
294             // remove any trailing ws
295             if let Some(last) = self.syntax().last_token().filter(|it| it.kind() == WHITESPACE) {
296                 last.detach();
297             }
298         }
299     }
300 }
301 
302 impl ast::ConstParam {
remove_default(&self)303     pub fn remove_default(&self) {
304         if let Some((eq, last)) = self
305             .syntax()
306             .children_with_tokens()
307             .find(|it| it.kind() == T![=])
308             .zip(self.syntax().last_child_or_token())
309         {
310             ted::remove_all(eq..=last);
311 
312             // remove any trailing ws
313             if let Some(last) = self.syntax().last_token().filter(|it| it.kind() == WHITESPACE) {
314                 last.detach();
315             }
316         }
317     }
318 }
319 
320 pub trait Removable: AstNode {
remove(&self)321     fn remove(&self);
322 }
323 
324 impl Removable for ast::TypeBoundList {
remove(&self)325     fn remove(&self) {
326         match self.syntax().siblings_with_tokens(Direction::Prev).find(|it| it.kind() == T![:]) {
327             Some(colon) => ted::remove_all(colon..=self.syntax().clone().into()),
328             None => ted::remove(self.syntax()),
329         }
330     }
331 }
332 
333 impl ast::PathSegment {
get_or_create_generic_arg_list(&self) -> ast::GenericArgList334     pub fn get_or_create_generic_arg_list(&self) -> ast::GenericArgList {
335         if self.generic_arg_list().is_none() {
336             let arg_list = make::generic_arg_list(empty()).clone_for_update();
337             ted::append_child(self.syntax(), arg_list.syntax());
338         }
339         self.generic_arg_list().unwrap()
340     }
341 }
342 
343 impl Removable for ast::UseTree {
remove(&self)344     fn remove(&self) {
345         for dir in [Direction::Next, Direction::Prev] {
346             if let Some(next_use_tree) = neighbor(self, dir) {
347                 let separators = self
348                     .syntax()
349                     .siblings_with_tokens(dir)
350                     .skip(1)
351                     .take_while(|it| it.as_node() != Some(next_use_tree.syntax()));
352                 ted::remove_all_iter(separators);
353                 break;
354             }
355         }
356         ted::remove(self.syntax());
357     }
358 }
359 
360 impl ast::UseTree {
get_or_create_use_tree_list(&self) -> ast::UseTreeList361     pub fn get_or_create_use_tree_list(&self) -> ast::UseTreeList {
362         match self.use_tree_list() {
363             Some(it) => it,
364             None => {
365                 let position = Position::last_child_of(self.syntax());
366                 let use_tree_list = make::use_tree_list(empty()).clone_for_update();
367                 let mut elements = Vec::with_capacity(2);
368                 if self.coloncolon_token().is_none() {
369                     elements.push(make::token(T![::]).into());
370                 }
371                 elements.push(use_tree_list.syntax().clone().into());
372                 ted::insert_all_raw(position, elements);
373                 use_tree_list
374             }
375         }
376     }
377 
378     /// Splits off the given prefix, making it the path component of the use tree,
379     /// appending the rest of the path to all UseTreeList items.
380     ///
381     /// # Examples
382     ///
383     /// `prefix$0::suffix` -> `prefix::{suffix}`
384     ///
385     /// `prefix$0` -> `prefix::{self}`
386     ///
387     /// `prefix$0::*` -> `prefix::{*}`
split_prefix(&self, prefix: &ast::Path)388     pub fn split_prefix(&self, prefix: &ast::Path) {
389         debug_assert_eq!(self.path(), Some(prefix.top_path()));
390         let path = self.path().unwrap();
391         if &path == prefix && self.use_tree_list().is_none() {
392             if self.star_token().is_some() {
393                 // path$0::* -> *
394                 self.coloncolon_token().map(ted::remove);
395                 ted::remove(prefix.syntax());
396             } else {
397                 // path$0 -> self
398                 let self_suffix =
399                     make::path_unqualified(make::path_segment_self()).clone_for_update();
400                 ted::replace(path.syntax(), self_suffix.syntax());
401             }
402         } else if split_path_prefix(prefix).is_none() {
403             return;
404         }
405         // At this point, prefix path is detached; _self_ use tree has suffix path.
406         // Next, transform 'suffix' use tree into 'prefix::{suffix}'
407         let subtree = self.clone_subtree().clone_for_update();
408         ted::remove_all_iter(self.syntax().children_with_tokens());
409         ted::insert(Position::first_child_of(self.syntax()), prefix.syntax());
410         self.get_or_create_use_tree_list().add_use_tree(subtree);
411 
412         fn split_path_prefix(prefix: &ast::Path) -> Option<()> {
413             let parent = prefix.parent_path()?;
414             let segment = parent.segment()?;
415             if algo::has_errors(segment.syntax()) {
416                 return None;
417             }
418             for p in successors(parent.parent_path(), |it| it.parent_path()) {
419                 p.segment()?;
420             }
421             prefix.parent_path().and_then(|p| p.coloncolon_token()).map(ted::remove);
422             ted::remove(prefix.syntax());
423             Some(())
424         }
425     }
426 }
427 
428 impl ast::UseTreeList {
add_use_tree(&self, use_tree: ast::UseTree)429     pub fn add_use_tree(&self, use_tree: ast::UseTree) {
430         let (position, elements) = match self.use_trees().last() {
431             Some(last_tree) => (
432                 Position::after(last_tree.syntax()),
433                 vec![
434                     make::token(T![,]).into(),
435                     make::tokens::single_space().into(),
436                     use_tree.syntax.into(),
437                 ],
438             ),
439             None => {
440                 let position = match self.l_curly_token() {
441                     Some(l_curly) => Position::after(l_curly),
442                     None => Position::last_child_of(self.syntax()),
443                 };
444                 (position, vec![use_tree.syntax.into()])
445             }
446         };
447         ted::insert_all_raw(position, elements);
448     }
449 }
450 
451 impl Removable for ast::Use {
remove(&self)452     fn remove(&self) {
453         let next_ws = self
454             .syntax()
455             .next_sibling_or_token()
456             .and_then(|it| it.into_token())
457             .and_then(ast::Whitespace::cast);
458         if let Some(next_ws) = next_ws {
459             let ws_text = next_ws.syntax().text();
460             if let Some(rest) = ws_text.strip_prefix('\n') {
461                 if rest.is_empty() {
462                     ted::remove(next_ws.syntax());
463                 } else {
464                     ted::replace(next_ws.syntax(), make::tokens::whitespace(rest));
465                 }
466             }
467         }
468         ted::remove(self.syntax());
469     }
470 }
471 
472 impl ast::Impl {
get_or_create_assoc_item_list(&self) -> ast::AssocItemList473     pub fn get_or_create_assoc_item_list(&self) -> ast::AssocItemList {
474         if self.assoc_item_list().is_none() {
475             let assoc_item_list = make::assoc_item_list().clone_for_update();
476             ted::append_child(self.syntax(), assoc_item_list.syntax());
477         }
478         self.assoc_item_list().unwrap()
479     }
480 }
481 
482 impl ast::AssocItemList {
483     /// Attention! This function does align the first line of `item` with respect to `self`,
484     /// but it does _not_ change indentation of other lines (if any).
add_item(&self, item: ast::AssocItem)485     pub fn add_item(&self, item: ast::AssocItem) {
486         let (indent, position, whitespace) = match self.assoc_items().last() {
487             Some(last_item) => (
488                 IndentLevel::from_node(last_item.syntax()),
489                 Position::after(last_item.syntax()),
490                 "\n\n",
491             ),
492             None => match self.l_curly_token() {
493                 Some(l_curly) => {
494                     normalize_ws_between_braces(self.syntax());
495                     (IndentLevel::from_token(&l_curly) + 1, Position::after(&l_curly), "\n")
496                 }
497                 None => (IndentLevel::single(), Position::last_child_of(self.syntax()), "\n"),
498             },
499         };
500         let elements: Vec<SyntaxElement<_>> = vec![
501             make::tokens::whitespace(&format!("{whitespace}{indent}")).into(),
502             item.syntax().clone().into(),
503         ];
504         ted::insert_all(position, elements);
505     }
506 }
507 
508 impl ast::Fn {
get_or_create_body(&self) -> ast::BlockExpr509     pub fn get_or_create_body(&self) -> ast::BlockExpr {
510         if self.body().is_none() {
511             let body = make::ext::empty_block_expr().clone_for_update();
512             match self.semicolon_token() {
513                 Some(semi) => {
514                     ted::replace(semi, body.syntax());
515                     ted::insert(Position::before(body.syntax), make::tokens::single_space());
516                 }
517                 None => ted::append_child(self.syntax(), body.syntax()),
518             }
519         }
520         self.body().unwrap()
521     }
522 }
523 
524 impl Removable for ast::MatchArm {
remove(&self)525     fn remove(&self) {
526         if let Some(sibling) = self.syntax().prev_sibling_or_token() {
527             if sibling.kind() == SyntaxKind::WHITESPACE {
528                 ted::remove(sibling);
529             }
530         }
531         if let Some(sibling) = self.syntax().next_sibling_or_token() {
532             if sibling.kind() == T![,] {
533                 ted::remove(sibling);
534             }
535         }
536         ted::remove(self.syntax());
537     }
538 }
539 
540 impl ast::MatchArmList {
add_arm(&self, arm: ast::MatchArm)541     pub fn add_arm(&self, arm: ast::MatchArm) {
542         normalize_ws_between_braces(self.syntax());
543         let mut elements = Vec::new();
544         let position = match self.arms().last() {
545             Some(last_arm) => {
546                 if needs_comma(&last_arm) {
547                     ted::append_child(last_arm.syntax(), make::token(SyntaxKind::COMMA));
548                 }
549                 Position::after(last_arm.syntax().clone())
550             }
551             None => match self.l_curly_token() {
552                 Some(it) => Position::after(it),
553                 None => Position::last_child_of(self.syntax()),
554             },
555         };
556         let indent = IndentLevel::from_node(self.syntax()) + 1;
557         elements.push(make::tokens::whitespace(&format!("\n{indent}")).into());
558         elements.push(arm.syntax().clone().into());
559         if needs_comma(&arm) {
560             ted::append_child(arm.syntax(), make::token(SyntaxKind::COMMA));
561         }
562         ted::insert_all(position, elements);
563 
564         fn needs_comma(arm: &ast::MatchArm) -> bool {
565             arm.expr().map_or(false, |e| !e.is_block_like()) && arm.comma_token().is_none()
566         }
567     }
568 }
569 
570 impl ast::RecordExprFieldList {
add_field(&self, field: ast::RecordExprField)571     pub fn add_field(&self, field: ast::RecordExprField) {
572         let is_multiline = self.syntax().text().contains_char('\n');
573         let whitespace = if is_multiline {
574             let indent = IndentLevel::from_node(self.syntax()) + 1;
575             make::tokens::whitespace(&format!("\n{indent}"))
576         } else {
577             make::tokens::single_space()
578         };
579 
580         if is_multiline {
581             normalize_ws_between_braces(self.syntax());
582         }
583 
584         let position = match self.fields().last() {
585             Some(last_field) => {
586                 let comma = get_or_insert_comma_after(last_field.syntax());
587                 Position::after(comma)
588             }
589             None => match self.l_curly_token() {
590                 Some(it) => Position::after(it),
591                 None => Position::last_child_of(self.syntax()),
592             },
593         };
594 
595         ted::insert_all(position, vec![whitespace.into(), field.syntax().clone().into()]);
596         if is_multiline {
597             ted::insert(Position::after(field.syntax()), ast::make::token(T![,]));
598         }
599     }
600 }
601 
602 impl ast::RecordExprField {
603     /// This will either replace the initializer, or in the case that this is a shorthand convert
604     /// the initializer into the name ref and insert the expr as the new initializer.
replace_expr(&self, expr: ast::Expr)605     pub fn replace_expr(&self, expr: ast::Expr) {
606         if self.name_ref().is_some() {
607             match self.expr() {
608                 Some(prev) => ted::replace(prev.syntax(), expr.syntax()),
609                 None => ted::append_child(self.syntax(), expr.syntax()),
610             }
611             return;
612         }
613         // this is a shorthand
614         if let Some(ast::Expr::PathExpr(path_expr)) = self.expr() {
615             if let Some(path) = path_expr.path() {
616                 if let Some(name_ref) = path.as_single_name_ref() {
617                     path_expr.syntax().detach();
618                     let children = vec![
619                         name_ref.syntax().clone().into(),
620                         ast::make::token(T![:]).into(),
621                         ast::make::tokens::single_space().into(),
622                         expr.syntax().clone().into(),
623                     ];
624                     ted::insert_all_raw(Position::last_child_of(self.syntax()), children);
625                 }
626             }
627         }
628     }
629 }
630 
631 impl ast::RecordPatFieldList {
add_field(&self, field: ast::RecordPatField)632     pub fn add_field(&self, field: ast::RecordPatField) {
633         let is_multiline = self.syntax().text().contains_char('\n');
634         let whitespace = if is_multiline {
635             let indent = IndentLevel::from_node(self.syntax()) + 1;
636             make::tokens::whitespace(&format!("\n{indent}"))
637         } else {
638             make::tokens::single_space()
639         };
640 
641         if is_multiline {
642             normalize_ws_between_braces(self.syntax());
643         }
644 
645         let position = match self.fields().last() {
646             Some(last_field) => {
647                 let syntax = last_field.syntax();
648                 let comma = get_or_insert_comma_after(syntax);
649                 Position::after(comma)
650             }
651             None => match self.l_curly_token() {
652                 Some(it) => Position::after(it),
653                 None => Position::last_child_of(self.syntax()),
654             },
655         };
656 
657         ted::insert_all(position, vec![whitespace.into(), field.syntax().clone().into()]);
658         if is_multiline {
659             ted::insert(Position::after(field.syntax()), ast::make::token(T![,]));
660         }
661     }
662 }
663 
get_or_insert_comma_after(syntax: &SyntaxNode) -> SyntaxToken664 fn get_or_insert_comma_after(syntax: &SyntaxNode) -> SyntaxToken {
665     match syntax
666         .siblings_with_tokens(Direction::Next)
667         .filter_map(|it| it.into_token())
668         .find(|it| it.kind() == T![,])
669     {
670         Some(it) => it,
671         None => {
672             let comma = ast::make::token(T![,]);
673             ted::insert(Position::after(syntax), &comma);
674             comma
675         }
676     }
677 }
678 
679 impl ast::StmtList {
push_front(&self, statement: ast::Stmt)680     pub fn push_front(&self, statement: ast::Stmt) {
681         ted::insert(Position::after(self.l_curly_token().unwrap()), statement.syntax());
682     }
683 }
684 
685 impl ast::VariantList {
add_variant(&self, variant: ast::Variant)686     pub fn add_variant(&self, variant: ast::Variant) {
687         let (indent, position) = match self.variants().last() {
688             Some(last_item) => (
689                 IndentLevel::from_node(last_item.syntax()),
690                 Position::after(get_or_insert_comma_after(last_item.syntax())),
691             ),
692             None => match self.l_curly_token() {
693                 Some(l_curly) => {
694                     normalize_ws_between_braces(self.syntax());
695                     (IndentLevel::from_token(&l_curly) + 1, Position::after(&l_curly))
696                 }
697                 None => (IndentLevel::single(), Position::last_child_of(self.syntax())),
698             },
699         };
700         let elements: Vec<SyntaxElement<_>> = vec![
701             make::tokens::whitespace(&format!("{}{indent}", "\n")).into(),
702             variant.syntax().clone().into(),
703             ast::make::token(T![,]).into(),
704         ];
705         ted::insert_all(position, elements);
706     }
707 }
708 
normalize_ws_between_braces(node: &SyntaxNode) -> Option<()>709 fn normalize_ws_between_braces(node: &SyntaxNode) -> Option<()> {
710     let l = node
711         .children_with_tokens()
712         .filter_map(|it| it.into_token())
713         .find(|it| it.kind() == T!['{'])?;
714     let r = node
715         .children_with_tokens()
716         .filter_map(|it| it.into_token())
717         .find(|it| it.kind() == T!['}'])?;
718 
719     let indent = IndentLevel::from_node(node);
720 
721     match l.next_sibling_or_token() {
722         Some(ws) if ws.kind() == SyntaxKind::WHITESPACE => {
723             if ws.next_sibling_or_token()?.into_token()? == r {
724                 ted::replace(ws, make::tokens::whitespace(&format!("\n{indent}")));
725             }
726         }
727         Some(ws) if ws.kind() == T!['}'] => {
728             ted::insert(Position::after(l), make::tokens::whitespace(&format!("\n{indent}")));
729         }
730         _ => (),
731     }
732     Some(())
733 }
734 
735 pub trait Indent: AstNode + Clone + Sized {
indent_level(&self) -> IndentLevel736     fn indent_level(&self) -> IndentLevel {
737         IndentLevel::from_node(self.syntax())
738     }
indent(&self, by: IndentLevel)739     fn indent(&self, by: IndentLevel) {
740         by.increase_indent(self.syntax());
741     }
dedent(&self, by: IndentLevel)742     fn dedent(&self, by: IndentLevel) {
743         by.decrease_indent(self.syntax());
744     }
reindent_to(&self, target_level: IndentLevel)745     fn reindent_to(&self, target_level: IndentLevel) {
746         let current_level = IndentLevel::from_node(self.syntax());
747         self.dedent(current_level);
748         self.indent(target_level);
749     }
750 }
751 
752 impl<N: AstNode + Clone> Indent for N {}
753 
#[cfg(test)]
mod tests {
    use std::fmt;

    use stdx::trim_indent;
    use test_utils::assert_eq_text;

    use crate::SourceFile;

    use super::*;

    /// Parses `text` and returns a mutable (`clone_for_update`) copy of the
    /// first node of type `N` found in the resulting tree.
    ///
    /// Panics if `text` contains no node of type `N`.
    fn ast_mut_from_text<N: AstNode>(text: &str) -> N {
        let parse = SourceFile::parse(text);
        parse
            .tree()
            .syntax()
            .descendants()
            .find_map(N::cast)
            .expect("no node of the requested kind in test text")
            .clone_for_update()
    }

    #[test]
    fn test_create_generic_param_list() {
        fn check_create_gpl<N: GenericParamsOwnerEdit + fmt::Display>(before: &str, after: &str) {
            let gpl_owner = ast_mut_from_text::<N>(before);
            gpl_owner.get_or_create_generic_param_list();
            assert_eq!(gpl_owner.to_string(), after);
        }

        check_create_gpl::<ast::Fn>("fn foo", "fn foo<>");
        check_create_gpl::<ast::Fn>("fn foo() {}", "fn foo<>() {}");

        check_create_gpl::<ast::Impl>("impl", "impl<>");
        check_create_gpl::<ast::Impl>("impl Struct {}", "impl<> Struct {}");
        check_create_gpl::<ast::Impl>("impl Trait for Struct {}", "impl<> Trait for Struct {}");

        check_create_gpl::<ast::Trait>("trait Trait<>", "trait Trait<>");
        check_create_gpl::<ast::Trait>("trait Trait<> {}", "trait Trait<> {}");

        check_create_gpl::<ast::Struct>("struct A", "struct A<>");
        check_create_gpl::<ast::Struct>("struct A;", "struct A<>;");
        check_create_gpl::<ast::Struct>("struct A();", "struct A<>();");
        check_create_gpl::<ast::Struct>("struct A {}", "struct A<> {}");

        check_create_gpl::<ast::Enum>("enum E", "enum E<>");
        check_create_gpl::<ast::Enum>("enum E {", "enum E<> {");
    }

    #[test]
    fn test_increase_indent() {
        let fn_ = ast_mut_from_text::<ast::Fn>(
            "fn foo() {
    ;
    ;
}",
        );
        fn_.indent(IndentLevel(2));
        assert_eq!(
            fn_.to_string(),
            "fn foo() {
            ;
            ;
        }",
        );
    }

    #[test]
    fn add_variant_to_empty_enum() {
        let variant = make::variant(make::name("Bar"), None).clone_for_update();

        check_add_variant(
            r#"
enum Foo {}
"#,
            r#"
enum Foo {
    Bar,
}
"#,
            variant,
        );
    }

    #[test]
    fn add_variant_to_non_empty_enum() {
        let variant = make::variant(make::name("Baz"), None).clone_for_update();

        check_add_variant(
            r#"
enum Foo {
    Bar,
}
"#,
            r#"
enum Foo {
    Bar,
    Baz,
}
"#,
            variant,
        );
    }

    #[test]
    fn add_variant_with_tuple_field_list() {
        let variant = make::variant(
            make::name("Baz"),
            Some(ast::FieldList::TupleFieldList(make::tuple_field_list(std::iter::once(
                make::tuple_field(None, make::ty("bool")),
            )))),
        )
        .clone_for_update();

        check_add_variant(
            r#"
enum Foo {
    Bar,
}
"#,
            r#"
enum Foo {
    Bar,
    Baz(bool),
}
"#,
            variant,
        );
    }

    #[test]
    fn add_variant_with_record_field_list() {
        let variant = make::variant(
            make::name("Baz"),
            Some(ast::FieldList::RecordFieldList(make::record_field_list(std::iter::once(
                make::record_field(None, make::name("x"), make::ty("bool")),
            )))),
        )
        .clone_for_update();

        check_add_variant(
            r#"
enum Foo {
    Bar,
}
"#,
            r#"
enum Foo {
    Bar,
    Baz { x: bool },
}
"#,
            variant,
        );
    }

    /// Parses the first enum in `before`, adds `variant` to its variant list,
    /// and compares the printed enum against `expected` (both sides trimmed and
    /// de-indented).
    ///
    /// Panics if the parsed enum has no variant list, instead of silently
    /// skipping the edit and failing later on a confusing text diff.
    fn check_add_variant(before: &str, expected: &str, variant: ast::Variant) {
        let enum_ = ast_mut_from_text::<ast::Enum>(before);
        enum_
            .variant_list()
            .expect("test enum should have a variant list")
            .add_variant(variant);
        let after = enum_.to_string();
        assert_eq_text!(&trim_indent(expected.trim()), &trim_indent(after.trim()));
    }
}
911