• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 //! See [`PathTransform`].
2 
3 use crate::helpers::mod_path_to_ast;
4 use either::Either;
5 use hir::{AsAssocItem, HirDisplay, SemanticsScope};
6 use rustc_hash::FxHashMap;
7 use syntax::{
8     ast::{self, AstNode},
9     ted, SyntaxNode,
10 };
11 
12 #[derive(Default)]
13 struct AstSubsts {
14     types_and_consts: Vec<TypeOrConst>,
15     lifetimes: Vec<ast::LifetimeArg>,
16 }
17 
18 enum TypeOrConst {
19     Either(ast::TypeArg), // indistinguishable type or const param
20     Const(ast::ConstArg),
21 }
22 
23 type LifetimeName = String;
24 
25 /// `PathTransform` substitutes path in SyntaxNodes in bulk.
26 ///
27 /// This is mostly useful for IDE code generation. If you paste some existing
28 /// code into a new context (for example, to add method overrides to an `impl`
29 /// block), you generally want to appropriately qualify the names, and sometimes
30 /// you might want to substitute generic parameters as well:
31 ///
32 /// ```
33 /// mod x {
34 ///   pub struct A<V>;
35 ///   pub trait T<U> { fn foo(&self, _: U) -> A<U>; }
36 /// }
37 ///
38 /// mod y {
39 ///   use x::T;
40 ///
41 ///   impl T<()> for () {
42 ///      // If we invoke **Add Missing Members** here, we want to copy-paste `foo`.
43 ///      // But we want a slightly-modified version of it:
44 ///      fn foo(&self, _: ()) -> x::A<()> {}
45 ///   }
46 /// }
47 /// ```
48 pub struct PathTransform<'a> {
49     generic_def: Option<hir::GenericDef>,
50     substs: AstSubsts,
51     target_scope: &'a SemanticsScope<'a>,
52     source_scope: &'a SemanticsScope<'a>,
53 }
54 
55 impl<'a> PathTransform<'a> {
trait_impl( target_scope: &'a SemanticsScope<'a>, source_scope: &'a SemanticsScope<'a>, trait_: hir::Trait, impl_: ast::Impl, ) -> PathTransform<'a>56     pub fn trait_impl(
57         target_scope: &'a SemanticsScope<'a>,
58         source_scope: &'a SemanticsScope<'a>,
59         trait_: hir::Trait,
60         impl_: ast::Impl,
61     ) -> PathTransform<'a> {
62         PathTransform {
63             source_scope,
64             target_scope,
65             generic_def: Some(trait_.into()),
66             substs: get_syntactic_substs(impl_).unwrap_or_default(),
67         }
68     }
69 
function_call( target_scope: &'a SemanticsScope<'a>, source_scope: &'a SemanticsScope<'a>, function: hir::Function, generic_arg_list: ast::GenericArgList, ) -> PathTransform<'a>70     pub fn function_call(
71         target_scope: &'a SemanticsScope<'a>,
72         source_scope: &'a SemanticsScope<'a>,
73         function: hir::Function,
74         generic_arg_list: ast::GenericArgList,
75     ) -> PathTransform<'a> {
76         PathTransform {
77             source_scope,
78             target_scope,
79             generic_def: Some(function.into()),
80             substs: get_type_args_from_arg_list(generic_arg_list).unwrap_or_default(),
81         }
82     }
83 
generic_transformation( target_scope: &'a SemanticsScope<'a>, source_scope: &'a SemanticsScope<'a>, ) -> PathTransform<'a>84     pub fn generic_transformation(
85         target_scope: &'a SemanticsScope<'a>,
86         source_scope: &'a SemanticsScope<'a>,
87     ) -> PathTransform<'a> {
88         PathTransform {
89             source_scope,
90             target_scope,
91             generic_def: None,
92             substs: AstSubsts::default(),
93         }
94     }
95 
apply(&self, syntax: &SyntaxNode)96     pub fn apply(&self, syntax: &SyntaxNode) {
97         self.build_ctx().apply(syntax)
98     }
99 
apply_all<'b>(&self, nodes: impl IntoIterator<Item = &'b SyntaxNode>)100     pub fn apply_all<'b>(&self, nodes: impl IntoIterator<Item = &'b SyntaxNode>) {
101         let ctx = self.build_ctx();
102         for node in nodes {
103             ctx.apply(node);
104         }
105     }
106 
build_ctx(&self) -> Ctx<'a>107     fn build_ctx(&self) -> Ctx<'a> {
108         let db = self.source_scope.db;
109         let target_module = self.target_scope.module();
110         let source_module = self.source_scope.module();
111         let skip = match self.generic_def {
112             // this is a trait impl, so we need to skip the first type parameter (i.e. Self) -- this is a bit hacky
113             Some(hir::GenericDef::Trait(_)) => 1,
114             _ => 0,
115         };
116         let mut type_substs: FxHashMap<hir::TypeParam, ast::Type> = Default::default();
117         let mut const_substs: FxHashMap<hir::ConstParam, SyntaxNode> = Default::default();
118         let mut default_types: Vec<hir::TypeParam> = Default::default();
119         self.generic_def
120             .into_iter()
121             .flat_map(|it| it.type_params(db))
122             .skip(skip)
123             // The actual list of trait type parameters may be longer than the one
124             // used in the `impl` block due to trailing default type parameters.
125             // For that case we extend the `substs` with an empty iterator so we
126             // can still hit those trailing values and check if they actually have
127             // a default type. If they do, go for that type from `hir` to `ast` so
128             // the resulting change can be applied correctly.
129             .zip(self.substs.types_and_consts.iter().map(Some).chain(std::iter::repeat(None)))
130             .for_each(|(k, v)| match (k.split(db), v) {
131                 (Either::Right(k), Some(TypeOrConst::Either(v))) => {
132                     if let Some(ty) = v.ty() {
133                         type_substs.insert(k, ty.clone());
134                     }
135                 }
136                 (Either::Right(k), None) => {
137                     if let Some(default) = k.default(db) {
138                         if let Some(default) =
139                             &default.display_source_code(db, source_module.into(), false).ok()
140                         {
141                             type_substs.insert(k, ast::make::ty(default).clone_for_update());
142                             default_types.push(k);
143                         }
144                     }
145                 }
146                 (Either::Left(k), Some(TypeOrConst::Either(v))) => {
147                     if let Some(ty) = v.ty() {
148                         const_substs.insert(k, ty.syntax().clone());
149                     }
150                 }
151                 (Either::Left(k), Some(TypeOrConst::Const(v))) => {
152                     if let Some(expr) = v.expr() {
153                         // FIXME: expressions in curly brackets can cause ambiguity after insertion
154                         // (e.g. `N * 2` -> `{1 + 1} * 2`; it's unclear whether `{1 + 1}`
155                         // is a standalone statement or a part of another expresson)
156                         // and sometimes require slight modifications; see
157                         // https://doc.rust-lang.org/reference/statements.html#expression-statements
158                         const_substs.insert(k, expr.syntax().clone());
159                     }
160                 }
161                 (Either::Left(_), None) => (), // FIXME: get default const value
162                 _ => (),                       // ignore mismatching params
163             });
164         let lifetime_substs: FxHashMap<_, _> = self
165             .generic_def
166             .into_iter()
167             .flat_map(|it| it.lifetime_params(db))
168             .zip(self.substs.lifetimes.clone())
169             .filter_map(|(k, v)| Some((k.name(db).display(db.upcast()).to_string(), v.lifetime()?)))
170             .collect();
171         let ctx = Ctx {
172             type_substs,
173             const_substs,
174             lifetime_substs,
175             target_module,
176             source_scope: self.source_scope,
177         };
178         ctx.transform_default_type_substs(default_types);
179         ctx
180     }
181 }
182 
183 struct Ctx<'a> {
184     type_substs: FxHashMap<hir::TypeParam, ast::Type>,
185     const_substs: FxHashMap<hir::ConstParam, SyntaxNode>,
186     lifetime_substs: FxHashMap<LifetimeName, ast::Lifetime>,
187     target_module: hir::Module,
188     source_scope: &'a SemanticsScope<'a>,
189 }
190 
postorder(item: &SyntaxNode) -> impl Iterator<Item = SyntaxNode>191 fn postorder(item: &SyntaxNode) -> impl Iterator<Item = SyntaxNode> {
192     item.preorder().filter_map(|event| match event {
193         syntax::WalkEvent::Enter(_) => None,
194         syntax::WalkEvent::Leave(node) => Some(node),
195     })
196 }
197 
198 impl<'a> Ctx<'a> {
apply(&self, item: &SyntaxNode)199     fn apply(&self, item: &SyntaxNode) {
200         // `transform_path` may update a node's parent and that would break the
201         // tree traversal. Thus all paths in the tree are collected into a vec
202         // so that such operation is safe.
203         let paths = postorder(item).filter_map(ast::Path::cast).collect::<Vec<_>>();
204         for path in paths {
205             self.transform_path(path);
206         }
207 
208         postorder(item).filter_map(ast::Lifetime::cast).for_each(|lifetime| {
209             if let Some(subst) = self.lifetime_substs.get(&lifetime.syntax().text().to_string()) {
210                 ted::replace(lifetime.syntax(), subst.clone_subtree().clone_for_update().syntax());
211             }
212         });
213     }
214 
transform_default_type_substs(&self, default_types: Vec<hir::TypeParam>)215     fn transform_default_type_substs(&self, default_types: Vec<hir::TypeParam>) {
216         for k in default_types {
217             let v = self.type_substs.get(&k).unwrap();
218             // `transform_path` may update a node's parent and that would break the
219             // tree traversal. Thus all paths in the tree are collected into a vec
220             // so that such operation is safe.
221             let paths = postorder(&v.syntax()).filter_map(ast::Path::cast).collect::<Vec<_>>();
222             for path in paths {
223                 self.transform_path(path);
224             }
225         }
226     }
227 
transform_path(&self, path: ast::Path) -> Option<()>228     fn transform_path(&self, path: ast::Path) -> Option<()> {
229         if path.qualifier().is_some() {
230             return None;
231         }
232         if path.segment().map_or(false, |s| {
233             s.param_list().is_some() || (s.self_token().is_some() && path.parent_path().is_none())
234         }) {
235             // don't try to qualify `Fn(Foo) -> Bar` paths, they are in prelude anyway
236             // don't try to qualify sole `self` either, they are usually locals, but are returned as modules due to namespace clashing
237             return None;
238         }
239 
240         let resolution = self.source_scope.speculative_resolve(&path)?;
241 
242         match resolution {
243             hir::PathResolution::TypeParam(tp) => {
244                 if let Some(subst) = self.type_substs.get(&tp) {
245                     let parent = path.syntax().parent()?;
246                     if let Some(parent) = ast::Path::cast(parent.clone()) {
247                         // Path inside path means that there is an associated
248                         // type/constant on the type parameter. It is necessary
249                         // to fully qualify the type with `as Trait`. Even
250                         // though it might be unnecessary if `subst` is generic
251                         // type, always fully qualifying the path is safer
252                         // because of potential clash of associated types from
253                         // multiple traits
254 
255                         let trait_ref = find_trait_for_assoc_item(
256                             self.source_scope,
257                             tp,
258                             parent.segment()?.name_ref()?,
259                         )
260                         .and_then(|trait_ref| {
261                             let found_path = self.target_module.find_use_path(
262                                 self.source_scope.db.upcast(),
263                                 hir::ModuleDef::Trait(trait_ref),
264                                 false,
265                             )?;
266                             match ast::make::ty_path(mod_path_to_ast(&found_path)) {
267                                 ast::Type::PathType(path_ty) => Some(path_ty),
268                                 _ => None,
269                             }
270                         });
271 
272                         let segment = ast::make::path_segment_ty(subst.clone(), trait_ref);
273                         let qualified =
274                             ast::make::path_from_segments(std::iter::once(segment), false);
275                         ted::replace(path.syntax(), qualified.clone_for_update().syntax());
276                     } else if let Some(path_ty) = ast::PathType::cast(parent) {
277                         ted::replace(
278                             path_ty.syntax(),
279                             subst.clone_subtree().clone_for_update().syntax(),
280                         );
281                     } else {
282                         ted::replace(
283                             path.syntax(),
284                             subst.clone_subtree().clone_for_update().syntax(),
285                         );
286                     }
287                 }
288             }
289             hir::PathResolution::Def(def) if def.as_assoc_item(self.source_scope.db).is_none() => {
290                 if let hir::ModuleDef::Trait(_) = def {
291                     if matches!(path.segment()?.kind()?, ast::PathSegmentKind::Type { .. }) {
292                         // `speculative_resolve` resolves segments like `<T as
293                         // Trait>` into `Trait`, but just the trait name should
294                         // not be used as the replacement of the original
295                         // segment.
296                         return None;
297                     }
298                 }
299 
300                 let found_path =
301                     self.target_module.find_use_path(self.source_scope.db.upcast(), def, false)?;
302                 let res = mod_path_to_ast(&found_path).clone_for_update();
303                 if let Some(args) = path.segment().and_then(|it| it.generic_arg_list()) {
304                     if let Some(segment) = res.segment() {
305                         let old = segment.get_or_create_generic_arg_list();
306                         ted::replace(old.syntax(), args.clone_subtree().syntax().clone_for_update())
307                     }
308                 }
309                 ted::replace(path.syntax(), res.syntax())
310             }
311             hir::PathResolution::ConstParam(cp) => {
312                 if let Some(subst) = self.const_substs.get(&cp) {
313                     ted::replace(path.syntax(), subst.clone_subtree().clone_for_update());
314                 }
315             }
316             hir::PathResolution::Local(_)
317             | hir::PathResolution::SelfType(_)
318             | hir::PathResolution::Def(_)
319             | hir::PathResolution::BuiltinAttr(_)
320             | hir::PathResolution::ToolModule(_)
321             | hir::PathResolution::DeriveHelper(_) => (),
322         }
323         Some(())
324     }
325 }
326 
327 // FIXME: It would probably be nicer if we could get this via HIR (i.e. get the
328 // trait ref, and then go from the types in the substs back to the syntax).
get_syntactic_substs(impl_def: ast::Impl) -> Option<AstSubsts>329 fn get_syntactic_substs(impl_def: ast::Impl) -> Option<AstSubsts> {
330     let target_trait = impl_def.trait_()?;
331     let path_type = match target_trait {
332         ast::Type::PathType(path) => path,
333         _ => return None,
334     };
335     let generic_arg_list = path_type.path()?.segment()?.generic_arg_list()?;
336 
337     get_type_args_from_arg_list(generic_arg_list)
338 }
339 
get_type_args_from_arg_list(generic_arg_list: ast::GenericArgList) -> Option<AstSubsts>340 fn get_type_args_from_arg_list(generic_arg_list: ast::GenericArgList) -> Option<AstSubsts> {
341     let mut result = AstSubsts::default();
342     generic_arg_list.generic_args().for_each(|generic_arg| match generic_arg {
343         // Const params are marked as consts on definition only,
344         // being passed to the trait they are indistguishable from type params;
345         // anyway, we don't really need to distinguish them here.
346         ast::GenericArg::TypeArg(type_arg) => {
347             result.types_and_consts.push(TypeOrConst::Either(type_arg))
348         }
349         // Some const values are recognized correctly.
350         ast::GenericArg::ConstArg(const_arg) => {
351             result.types_and_consts.push(TypeOrConst::Const(const_arg));
352         }
353         ast::GenericArg::LifetimeArg(l_arg) => result.lifetimes.push(l_arg),
354         _ => (),
355     });
356 
357     Some(result)
358 }
359 
find_trait_for_assoc_item( scope: &SemanticsScope<'_>, type_param: hir::TypeParam, assoc_item: ast::NameRef, ) -> Option<hir::Trait>360 fn find_trait_for_assoc_item(
361     scope: &SemanticsScope<'_>,
362     type_param: hir::TypeParam,
363     assoc_item: ast::NameRef,
364 ) -> Option<hir::Trait> {
365     let db = scope.db;
366     let trait_bounds = type_param.trait_bounds(db);
367 
368     let assoc_item_name = assoc_item.text();
369 
370     for trait_ in trait_bounds {
371         let names = trait_.items(db).into_iter().filter_map(|item| match item {
372             hir::AssocItem::TypeAlias(ta) => Some(ta.name(db)),
373             hir::AssocItem::Const(cst) => cst.name(db),
374             _ => None,
375         });
376 
377         for name in names {
378             if assoc_item_name.as_str() == name.as_text()?.as_str() {
379                 // It is fine to return the first match because in case of
380                 // multiple possibilities, the exact trait must be disambiguated
381                 // in the definition of trait being implemented, so this search
382                 // should not be needed.
383                 return Some(trait_);
384             }
385         }
386     }
387 
388     None
389 }
390