• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2020-2024 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #ifndef MINDSPORE_CCSRC_FRONTEND_OPTIMIZER_IRPASS_INLINE_H_
18 #define MINDSPORE_CCSRC_FRONTEND_OPTIMIZER_IRPASS_INLINE_H_
19 
#include <algorithm>
#include <functional>
#include <memory>
#include <numeric>
#include <utility>
#include <vector>

#include "utils/hash_map.h"
#include "mindspore/core/ops/framework_ops.h"
#include "frontend/optimizer/irpass.h"
#include "include/common/utils/parallel_context.h"
#include "frontend/optimizer/optimizer.h"
#include "frontend/optimizer/anf_visitor.h"
#include "ir/func_graph.h"
#include "ir/func_graph_cloner.h"
#include "ir/tensor.h"
#include "frontend/operator/ops.h"
#include "abstract/abstract_value.h"
#include "include/common/utils/utils.h"
#include "pipeline/jit/ps/pipeline.h"
37 
38 namespace mindspore {
39 namespace opt {
40 namespace irpass {
41 class ReplaceApplicator : public AnfVisitor {
42  public:
operator()43   AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override {
44     if (!IsValueNode<FuncGraph>(node)) {
45       return nullptr;
46     }
47     auto fg = GetValueNode<FuncGraphPtr>(node);
48     if (NoInline(fg)) {
49       return nullptr;
50     }
51 
52     auto out = fg->output();
53     MS_EXCEPTION_IF_NULL(out);
54     if (!out->isa<CNode>()) {
55       return nullptr;
56     }
57 
58     auto &inputs = out->cast<CNodePtr>()->inputs();
59     auto params = fg->parameters();
60 
61     // Exclude first elements of inputs which is fn.
62     auto input_size = inputs.size();
63     auto param_size = params.size();
64     if ((input_size == 1 && param_size == 0) || (input_size > 1 && (input_size - 1) == param_size &&
65                                                  std::equal(inputs.begin() + 1, inputs.end(), params.begin()))) {
66       auto inner = inputs[0];
67       if (IsValueNode<Primitive>(inner)) {
68         return inner;
69       }
70       if (IsValueNode<FuncGraph>(inner) && GetValueNode<FuncGraphPtr>(inner)->parent() == nullptr) {
71         const auto &inner_fg = GetValueNode<FuncGraphPtr>(inner);
72         MS_EXCEPTION_IF_NULL(inner_fg);
73         bool is_recursive = (inner_fg->has_flag(FUNC_GRAPH_FLAG_NO_RECURSIVE) ? false : inner_fg->recursive());
74         if (is_recursive) {
75           return nullptr;
76         }
77         return inner;
78       }
79     }
80 
81     return nullptr;
82   }
83 
NoInline(const FuncGraphPtr & fg)84   bool NoInline(const FuncGraphPtr &fg) const {
85     if (fg == nullptr || fg->has_flag(FUNC_GRAPH_FLAG_NO_INLINE) || fg->has_flag(FUNC_GRAPH_FLAG_DEFER_INLINE) ||
86         fg->stub() || *(fg->indirect())) {
87       return true;
88     }
89     // Defer inlining in the case of pipeline.
90     auto stage_num = parallel::ParallelContext::GetInstance()->pipeline_stage_split_num();
91     if (fg->stage() != -1 && stage_num > 1) {
92       return true;
93     }
94     // Defer inlining for:
95     // 1. The func_graph which is set recomputed.
96     // 2. The k graph whose primal is set non-recomputed when enable graph reuse.
97     auto context = MsContext::GetInstance();
98     MS_EXCEPTION_IF_NULL(context);
99     const auto cell_reuse = context->CellReuseLevel() != CellReuseLevel::kNoCellReuse;
100     return fg->has_flag(FUNC_GRAPH_OUTPUT_NO_RECOMPUTE) ||
101            (cell_reuse &&
102             (fg->has_flag(FUNC_GRAPH_NOT_RECOMPUTE_K_GRAPH) || fg->has_flag(FUNC_GRAPH_RECOMPUTE_K_GRAPH)));
103   }
104 };
105 
106 class InlinerBase;
107 using CriterionFuncType = std::function<bool(InlinerBase *, const FuncGraphPtr &, const AnfNodePtr &)>;
108 
109 bool IsUniqueUse(InlinerBase *, const FuncGraphPtr &fg, const AnfNodePtr &);
110 
111 bool IsTrivial(InlinerBase *, const FuncGraphPtr &fg, const AnfNodePtr &);
112 bool IsInside(InlinerBase *, const FuncGraphPtr &, const AnfNodePtr &node);
113 bool IsCore(InlinerBase *, const FuncGraphPtr &fg, const AnfNodePtr &);
114 bool IsDirectParentCall(InlinerBase *, const FuncGraphPtr &fg, const AnfNodePtr &node);
115 bool IsNotRecursive(InlinerBase *inliner, const FuncGraphPtr &fg, const AnfNodePtr &);
IsForceInline(InlinerBase *,const FuncGraphPtr & fg,const AnfNodePtr &)116 bool IsForceInline(InlinerBase *, const FuncGraphPtr &fg, const AnfNodePtr &) {
117   return fg->has_flag(FUNC_GRAPH_FLAG_FORCE_INLINE);
118 }
119 
120 // {G, Xs}
121 class InlinerBase : public AnfVisitor {
122  public:
123   explicit InlinerBase(const std::vector<std::vector<CriterionFuncType>> &criterions, bool use_move = true)
use_move_(use_move)124       : use_move_(use_move), criterions_(criterions) {}
125   ~InlinerBase() override = default;
operator()126   AnfNodePtr operator()(const OptimizerPtr &optimizer, const AnfNodePtr &node) override {
127     auto cnode = dyn_cast<CNode>(node);
128     if (cnode == nullptr || cnode->size() < 1) {
129       return nullptr;
130     }
131 
132     // Check if no recursive flag was set in top graph.
133     CheckNoRecursive(optimizer);
134 
135     auto &inputs = cnode->inputs();
136     // G
137     auto fg = GetValueNode<FuncGraphPtr>(inputs[0]);
138     if (!CheckFlag(fg)) {
139       return nullptr;
140     }
141 
142     Reset();
143 
144     // 'criterions_': {criterion_group_1:{criterion1, criterion2, ...}, criterion_group_2:{...}, ...}
145     // All the criterions of 'criterion group' are true would set 'criterion group' as 'true'. As [AND].
146     // Anyone of 'criterion group' in 'criterions_' is 'true' would be matched. As [OR].
147     bool is_match = ApplyCriterions(node, fg);
148     if (!is_match) {
149       return nullptr;
150     }
151 
152     std::vector<AnfNodePtr> args;
153     (void)std::copy(inputs.begin() + 1, inputs.end(), std::back_inserter(args));
154     // Compare size to avoid the case that the function has default value after grad.
155     // for which after renormalize, the function default value will be an input
156     if (fg->parameters().size() != args.size()) {
157       return nullptr;
158     }
159 
160     if (IsForceInline(this, fg, node)) {
161       if (IsUniqueUse(nullptr, fg, nullptr)) {
162         return InlineMove(node, fg, args, inputs);
163       }
164       return InlineClone(fg, node->func_graph(), args, cnode);
165     }
166 
167     if (IsUniqueUse(nullptr, fg, nullptr)) {
168       // For the single used fg, including non-after and after not matched above,
169       // we move the whole fg nodes.
170       auto res_node = InlineForUniqueUse(node, fg, args, inputs);
171       if (res_node != nullptr) {
172         return res_node;
173       }
174     } else {
175       // We don't expand the middle multiple used after block, except the last one.
176       if (GraphHasBranch(fg)) {
177         return nullptr;
178       }
179       // Check if parameters' changed for the first met branch calling.
180       if (fg->has_flag(FUNC_GRAPH_FLAG_AFTER_BLOCK)) {
181         auto param_simplified_caller = SimplifyAfterParameter(fg, node, args);
182         if (param_simplified_caller != nullptr) {
183           return param_simplified_caller;
184         }
185       }
186     }
187     // Or, just make a clone for not single used fg.
188     auto res = InlineClone(fg, node->func_graph(), args, cnode);
189     return res;
190   }
191 
CheckFlag(const FuncGraphPtr & fg)192   bool CheckFlag(const FuncGraphPtr &fg) const {
193     if (fg == nullptr || fg->has_flag(FUNC_GRAPH_FLAG_NO_INLINE) || fg->has_flag(FUNC_GRAPH_FLAG_DEFER_INLINE) ||
194         fg->stub()) {
195       return false;
196     }
197     // Defer inlining in the case of pipeline.
198     auto stage_num = parallel::ParallelContext::GetInstance()->pipeline_stage_split_num();
199     if (fg->stage() != -1 && stage_num > 1) {
200       return false;
201     }
202     // Defer inlining for:
203     // 1. The func_graph which is set recomputed.
204     // 2. The k graph whose primal is set non-recomputed when enable graph reuse.
205     auto context = MsContext::GetInstance();
206     MS_EXCEPTION_IF_NULL(context);
207     const auto cell_reuse = context->CellReuseLevel() != CellReuseLevel::kNoCellReuse;
208     if (fg->has_flag(FUNC_GRAPH_OUTPUT_NO_RECOMPUTE) ||
209         (cell_reuse &&
210          (fg->has_flag(FUNC_GRAPH_NOT_RECOMPUTE_K_GRAPH) || fg->has_flag(FUNC_GRAPH_RECOMPUTE_K_GRAPH)))) {
211       return false;
212     }
213     return true;
214   }
215 
IsRecursive(const FuncGraphPtr & fg)216   bool IsRecursive(const FuncGraphPtr &fg) {
217     // The user guarantees that fg has no recursive.
218     if (no_recursive_) {
219       return false;
220     }
221 
222     if (!is_checked_) {
223       is_checked_ = true;
224       is_recursive_ = (fg->has_flag(FUNC_GRAPH_FLAG_NO_RECURSIVE) ? false : fg->recursive());
225     }
226     return is_recursive_;
227   }
228 
no_recursive()229   bool no_recursive() const { return no_recursive_; }
230 
231  private:
InlineMove(const AnfNodePtr & node,const FuncGraphPtr & fg,const std::vector<AnfNodePtr> & args,const std::vector<AnfNodePtr> & inputs)232   AnfNodePtr InlineMove(const AnfNodePtr &node, const FuncGraphPtr &fg, const std::vector<AnfNodePtr> &args,
233                         const std::vector<AnfNodePtr> &inputs) const {
234     auto mng = fg->manager();
235     MS_EXCEPTION_IF_NULL(mng);
236     ReplaceParams(mng, args, fg);
237     auto out_node = fg->output();
238     mng->MoveAllCNodeDropGraph(fg, node->func_graph(), node, inputs[0]->scope(), true);
239     return out_node;
240   }
241 
InlineForUniqueUse(const AnfNodePtr & node,const FuncGraphPtr & fg,const std::vector<AnfNodePtr> & args,const std::vector<AnfNodePtr> & inputs)242   AnfNodePtr InlineForUniqueUse(const AnfNodePtr &node, const FuncGraphPtr &fg, const std::vector<AnfNodePtr> &args,
243                                 const std::vector<AnfNodePtr> &inputs) const {
244     if (use_move_) {
245       return InlineMove(node, fg, args, inputs);
246     }
247 
248     // The other branch calling the last after block.
249     if (fg->has_flag(FUNC_GRAPH_FLAG_AFTER_BLOCK)) {
250       // Check if parameters' changed.
251       auto param_simplified_caller = SimplifyAfterParameter(fg, node, args);
252       if (param_simplified_caller != nullptr) {
253         return param_simplified_caller;
254       }
255     }
256     return nullptr;
257   }
258 
ApplyCriterions(const AnfNodePtr & node,const FuncGraphPtr & fg)259   bool ApplyCriterions(const AnfNodePtr &node, const FuncGraphPtr &fg) {
260     bool is_match = false;
261     for (auto &criterions : criterions_) {  // Each 'criterion group' in criterions_.
262       is_match = true;
263       for (auto &criterion : criterions) {  // Each criterion in 'criterion group'.
264         if (!criterion(this, fg, node)) {
265           is_match = false;
266           break;
267         }
268       }
269       if (is_match) {
270         break;
271       }
272     }
273     return is_match;
274   }
275 
ReplaceParams(const FuncGraphManagerPtr & mng,const std::vector<AnfNodePtr> & new_params,const FuncGraphPtr & fg)276   void ReplaceParams(const FuncGraphManagerPtr &mng, const std::vector<AnfNodePtr> &new_params,
277                      const FuncGraphPtr &fg) const {
278     auto params = fg->parameters();
279     auto old_size = params.size();
280     constexpr auto print_deep = 10;
281     if (old_size != new_params.size()) {
282       MS_LOG(INTERNAL_EXCEPTION) << "Parameter size not match." << old_size << " new " << new_params.size()
283                                  << fg->output()->DebugString(print_deep);
284     }
285     for (size_t i = 0; i < old_size; i++) {
286       (void)mng->Replace(params[i], new_params[i]);
287     }
288   }
289 
CheckNoRecursive(const OptimizerPtr & optimizer)290   void CheckNoRecursive(const OptimizerPtr &optimizer) {
291     // Check if no recursive flag was set in top graph.
292     const auto &resource = std::dynamic_pointer_cast<pipeline::Resource>(optimizer->resource());
293     if (resource == nullptr) {
294       return;
295     }
296     const auto &top_graph = resource->func_graph();
297     if (top_graph == nullptr) {
298       return;
299     }
300     if (top_graph->has_flag(FUNC_GRAPH_FLAG_NO_RECURSIVE)) {
301       no_recursive_ = true;
302     }
303   }
304 
Reset()305   void Reset() {
306     is_checked_ = false;
307     is_recursive_ = false;
308   }
309 
310   // For after block which contains branch call, delete the parameters which is not used.
311   // In most cases, it may be a `Module` or other constant input.
SimplifyAfterParameter(const FuncGraphPtr & fg,const AnfNodePtr & node,const std::vector<AnfNodePtr> & args)312   AnfNodePtr SimplifyAfterParameter(const FuncGraphPtr &fg, const AnfNodePtr &node,
313                                     const std::vector<AnfNodePtr> &args) const {
314     auto &fg_params = fg->parameters();
315     std::vector<int64_t> used_param_index;
316     auto mng = fg->manager();
317     MS_EXCEPTION_IF_NULL(mng);
318     bool should_simplify = false;
319     for (size_t i = 0; i < fg_params.size(); i++) {
320       if (mng->node_users()[fg_params[i]].size() != 0) {
321         (void)used_param_index.emplace_back(i);
322       } else {
323         MS_LOG(DEBUG) << "Not used parameter " << fg_params[i]->DebugString() << " for calling " << fg->ToString();
324         should_simplify = true;
325       }
326     }
327     if (!should_simplify) {
328       return nullptr;
329     }
330     MS_LOG(DEBUG) << "Parameter not used found for graph :" << fg->ToString();
331     // Clone a new graph and ignore the not used parameters
332     auto new_fg = TransformableClone(fg);
333     auto &new_fg_params = new_fg->parameters();
334     std::vector<AnfNodePtr> new_params;
335     std::transform(used_param_index.begin(), used_param_index.end(), std::back_inserter(new_params),
336                    [&new_fg_params](size_t i) { return new_fg_params[i]; });
337     new_fg->set_parameters(new_params);
338 
339     std::vector<AnfNodePtr> node_inputs;
340     node_inputs.push_back(NewValueNode(new_fg));
341     std::transform(used_param_index.begin(), used_param_index.end(), std::back_inserter(node_inputs),
342                    [&args](size_t i) { return args[i]; });
343     auto ret_node = node->func_graph()->NewCNode(node_inputs);
344     ret_node->set_abstract(node->abstract());
345     return ret_node;
346   }
347 
CheckSwitchInputs(const std::vector<AnfNodePtr> & sw_inputs)348   bool CheckSwitchInputs(const std::vector<AnfNodePtr> &sw_inputs) const {
349     // When branch has dead node or poly node, do not perform inline.
350     if (IsDeadNode(sw_inputs[kSwitchTrueBranchIndex]) || IsPolyNode(sw_inputs[kSwitchTrueBranchIndex]) ||
351         IsDeadNode(sw_inputs[kSwitchFalseBranchIndex]) || IsPolyNode(sw_inputs[kSwitchFalseBranchIndex])) {
352       return true;
353     }
354     return !sw_inputs[1]->isa<ValueNode>() || IsValueNode<tensor::Tensor>(sw_inputs[1]);
355   }
356 
357   // This is a try-best algorithm to find a graph which may generate branch call.
358   // It does not handle high-order function call. For high-orderer call branch, it still may be inlined.
GraphHasBranch(const FuncGraphPtr & fg)359   bool GraphHasBranch(const FuncGraphPtr &fg) {
360     if (graph_branch_cache_.find(fg) != graph_branch_cache_.end()) {
361       return graph_branch_cache_[fg];
362     }
363     bool has_branch = false;
364     auto nodes = fg->nodes();
365     for (auto &item : nodes) {
366       if (IsPrimitiveCNode(item, prim::kPrimSwitch)) {
367         auto sw_inputs = item->cast<CNodePtr>()->inputs();
368         if (sw_inputs.size() != kIndex4) {
369           MS_LOG(EXCEPTION) << "Switch inputs should be 4";
370         }
371         if (CheckSwitchInputs(sw_inputs)) {
372           has_branch = true;
373           break;
374         }
375       } else if (IsCNodeGraph(item)) {
376         auto cinputs = item->cast<CNodePtr>()->inputs();
377         if (cinputs.size() < 1) {
378           MS_LOG(EXCEPTION) << "Graph call inputs should be greater than 1";
379         }
380         FuncGraphPtr call_fg = GetValueNode<FuncGraphPtr>(cinputs[0]);
381         bool call_fg_has_branch = GraphHasBranch(call_fg);
382         if (call_fg_has_branch) {
383           has_branch = true;
384           break;
385         }
386       } else if (IsPrimitiveCNode(item, prim::kPrimPartial)) {
387         auto cinputs = item->cast<CNodePtr>()->inputs();
388         if (cinputs.size() < kIndex2) {
389           MS_LOG(EXCEPTION) << "Partial call inputs should be greater than 2";
390         }
391         FuncGraphPtr call_fg = GetValueNode<FuncGraphPtr>(cinputs[1]);
392         if (call_fg == nullptr) {
393           continue;
394         }
395         bool call_fg_has_branch = GraphHasBranch(call_fg);
396         if (call_fg_has_branch) {
397           has_branch = true;
398           break;
399         }
400       }
401     }
402     graph_branch_cache_[fg] = has_branch;
403     return has_branch;
404   }
405 
406   bool is_checked_{false};
407   bool is_recursive_{false};
408   // If the user guarantee that fg has no recursive.
409   bool no_recursive_{false};
410   bool use_move_;
411   std::vector<std::vector<CriterionFuncType>> criterions_;
412   mindspore::HashMap<FuncGraphPtr, bool> graph_branch_cache_;
413 };
414 
IsUniqueUse(InlinerBase *,const FuncGraphPtr & fg,const AnfNodePtr &)415 bool IsUniqueUse(InlinerBase *, const FuncGraphPtr &fg, const AnfNodePtr &) {
416   const auto &users = fg->func_graph_cnodes_index();
417   int64_t n_use = std::accumulate(
418     users.begin(), users.end(), 0,
419     [](int64_t sum, const std::pair<const CNodeIndexPairPtr, int64_t> &item) { return sum + item.second; });
420   return n_use == 1;
421 }
422 
IsTrivial(InlinerBase *,const FuncGraphPtr & fg,const AnfNodePtr &)423 bool IsTrivial(InlinerBase *, const FuncGraphPtr &fg, const AnfNodePtr &) {
424   auto n_cnode = fg->nodes().size() - fg->parameters().size();
425   // There is at least one CNode(return, other_node).
426   constexpr size_t least_size = 2;
427   return n_cnode <= least_size;
428 }
429 
IsInside(InlinerBase *,const FuncGraphPtr &,const AnfNodePtr & node)430 bool IsInside(InlinerBase *, const FuncGraphPtr &, const AnfNodePtr &node) {
431   MS_EXCEPTION_IF_NULL(node->func_graph());
432   return node->func_graph()->has_flag("inline_inside");
433 }
434 
IsCore(InlinerBase *,const FuncGraphPtr & fg,const AnfNodePtr &)435 bool IsCore(InlinerBase *, const FuncGraphPtr &fg, const AnfNodePtr &) { return fg->has_flag("core"); }
436 
IsDirectParentCall(InlinerBase * inliner,const FuncGraphPtr & fg,const AnfNodePtr & node)437 bool IsDirectParentCall(InlinerBase *inliner, const FuncGraphPtr &fg, const AnfNodePtr &node) {
438   bool is_recursive = (inliner->no_recursive() ? false : fg->recursive());
439   if (fg->parent() != nullptr && is_recursive) {
440     if (fg->parent() == node->func_graph() && IsUniqueUse(nullptr, fg, nullptr)) {
441       return true;
442     }
443   }
444   return false;
445 }
446 
IsNotRecursive(InlinerBase * inliner,const FuncGraphPtr & fg,const AnfNodePtr &)447 bool IsNotRecursive(InlinerBase *inliner, const FuncGraphPtr &fg, const AnfNodePtr &) {
448   return !inliner->IsRecursive(fg);
449 }
450 
451 class Inliner : public InlinerBase {
452  public:
453   explicit Inliner(bool use_move = true)
454       : InlinerBase(
455           // Supports AND conditions in one criterion, Ex. {IsUniqueUse, IsNotRecursive}.
456           {
457             {IsTrivial},
458             {IsInside},
459             {IsCore},
460             {IsNotRecursive},
461             {IsDirectParentCall},
462           },
463           use_move) {}
464 
465   ~Inliner() override = default;
466 };
467 
468 class DirectInliner : public InlinerBase {
469  public:
470   explicit DirectInliner(bool use_move = true)
471       : InlinerBase(
472           // Supports AND conditions in one criterion, Ex. {IsUniqueUse, IsNotRecursive}.
473           {
474             {IsForceInline},
475             {IsDirectParentCall},
476           },
477           use_move) {}
478   ~DirectInliner() override = default;
479 };
480 }  // namespace irpass
481 }  // namespace opt
482 }  // namespace mindspore
483 #endif  // MINDSPORE_CCSRC_FRONTEND_OPTIMIZER_IRPASS_INLINE_H_
484