• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2020-2021 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #ifndef MINDSPORE_CCSRC_FRONTEND_OPTIMIZER_IRPASS_INLINE_H_
18 #define MINDSPORE_CCSRC_FRONTEND_OPTIMIZER_IRPASS_INLINE_H_
19 
#include <algorithm>
#include <functional>
#include <iterator>
#include <numeric>
#include <unordered_map>
#include <utility>
#include <vector>

#include "frontend/optimizer/irpass.h"
#include "frontend/parallel/context.h"
#include "frontend/optimizer/optimizer.h"
#include "frontend/optimizer/anf_visitor.h"
#include "ir/func_graph.h"
#include "ir/func_graph_cloner.h"
#include "ir/tensor.h"
#include "frontend/operator/ops.h"
#include "abstract/abstract_value.h"
#include "utils/utils.h"
35 
36 namespace mindspore {
37 namespace opt {
38 namespace irpass {
39 class ReplaceApplicator : public AnfVisitor {
40  public:
operator()41   AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override {
42     if (!IsValueNode<FuncGraph>(node)) {
43       return nullptr;
44     }
45 
46     auto fg = GetValueNode<FuncGraphPtr>(node);
47     if (fg->has_flag(FUNC_GRAPH_FLAG_DEFER_INLINE) || fg->stub() || *(fg->switch_input()) ||
48         *(fg->switch_layer_input())) {
49       return nullptr;
50     }
51     // Defer inlining in the case of pipeline.
52     auto stage_num = parallel::ParallelContext::GetInstance()->pipeline_stage_split_num();
53     if (fg->stage() != -1 && stage_num > 1) {
54       return nullptr;
55     }
56     // Defer inlining to get the output nodes of the recomputed cell whose output is non-recomputed.
57     if (fg->has_flag(FUNC_GRAPH_OUTPUT_NO_RECOMPUTE)) {
58       return nullptr;
59     }
60 
61     auto out = fg->output();
62     MS_EXCEPTION_IF_NULL(out);
63     if (!out->isa<CNode>()) {
64       return nullptr;
65     }
66 
67     auto &inputs = out->cast<CNodePtr>()->inputs();
68     auto params = fg->parameters();
69 
70     // Exclude first elements of inputs which is fn.
71     auto input_size = inputs.size();
72     auto param_size = params.size();
73     if ((input_size == 1 && param_size == 0) || (input_size > 1 && (input_size - 1) == param_size &&
74                                                  std::equal(inputs.begin() + 1, inputs.end(), params.begin()))) {
75       auto inner = inputs[0];
76       if (IsValueNode<Primitive>(inner) ||
77           (IsValueNode<FuncGraph>(inner) && GetValueNode<FuncGraphPtr>(inner)->parent() == nullptr)) {
78         return inner;
79       }
80     }
81 
82     return nullptr;
83   }
84 };
85 
86 class InlinerBase;
87 using CriterionFuncType = std::function<bool(InlinerBase *, const FuncGraphPtr &, const AnfNodePtr &)>;
88 
89 bool IsUniqueUse(InlinerBase *, const FuncGraphPtr &fg, const AnfNodePtr &);
90 
91 bool IsTrivial(InlinerBase *, const FuncGraphPtr &fg, const AnfNodePtr &);
92 bool IsInside(InlinerBase *, const FuncGraphPtr &, const AnfNodePtr &node);
93 bool IsCore(InlinerBase *, const FuncGraphPtr &fg, const AnfNodePtr &);
94 bool IsDirectParentCall(InlinerBase *, const FuncGraphPtr &fg, const AnfNodePtr &node);
95 bool IsNotRecursive(InlinerBase *inliner, const FuncGraphPtr &fg, const AnfNodePtr &);
IsForceInline(InlinerBase *,const FuncGraphPtr & fg,const AnfNodePtr &)96 bool IsForceInline(InlinerBase *, const FuncGraphPtr &fg, const AnfNodePtr &) {
97   return fg->has_flag(FUNC_GRAPH_FLAG_FORCE_INLINE);
98 }
99 
100 // {G, Xs}
101 class InlinerBase : public AnfVisitor {
102  public:
103   explicit InlinerBase(std::vector<std::vector<CriterionFuncType>> criterions, bool use_move = true)
use_move_(use_move)104       : use_move_(use_move), criterions_(criterions) {}
105   ~InlinerBase() override = default;
operator()106   AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override {
107     auto cnode = dyn_cast<CNode>(node);
108     if (cnode == nullptr || cnode->size() < 1) {
109       return nullptr;
110     }
111 
112     auto &inputs = cnode->inputs();
113     // G
114     auto fg = GetValueNode<FuncGraphPtr>(inputs[0]);
115     if (fg == nullptr || fg->has_flag(FUNC_GRAPH_FLAG_DEFER_INLINE) || fg->stub()) {
116       return nullptr;
117     }
118     // Defer inlining in the case of pipeline.
119     auto stage_num = parallel::ParallelContext::GetInstance()->pipeline_stage_split_num();
120     if (fg->stage() != -1 && stage_num > 1) {
121       return nullptr;
122     }
123     // Defer inlining to get the output nodes of the recomputed cell whose output is non-recomputed.
124     if (fg->has_flag(FUNC_GRAPH_OUTPUT_NO_RECOMPUTE)) {
125       return nullptr;
126     }
127 
128     Reset();
129 
130     // 'criterions_': {criterion_group_1:{criterion1, criterion2, ...}, criterion_group_2:{...}, ...}
131     // All the criterions of 'criterion group' are true would set 'criterion group' as 'true'. As [AND].
132     // Anyone of 'criterion group' in 'criterions_' is 'true' would be matched. As [OR].
133     bool is_match = ApplyCriterions(node, fg);
134     if (!is_match) {
135       return nullptr;
136     }
137 
138     std::vector<AnfNodePtr> args;
139     (void)std::copy(inputs.begin() + 1, inputs.end(), std::back_inserter(args));
140     // Compare size to avoid the case that the function has default value after grad.
141     // for which after renormalize, the function default value will be an input
142     if (fg->parameters().size() != args.size()) {
143       return nullptr;
144     }
145 
146     if (IsForceInline(this, fg, node)) {
147       if (IsUniqueUse(nullptr, fg, nullptr)) {
148         return InlineMove(node, fg, args, inputs);
149       } else {
150         return InlineClone(fg, node->func_graph(), args, inputs[0]->scope());
151       }
152     }
153 
154     if (IsUniqueUse(nullptr, fg, nullptr)) {
155       // For the single used fg, including non-after and after not matched above,
156       // we move the whole fg nodes.
157       auto ret_node = InlineForUniqueUse(node, fg, args, inputs);
158       if (ret_node != nullptr) {
159         return ret_node;
160       }
161     } else {
162       // We don't expand the middle multiple used after block, except the last one.
163       if (GraphHasBranch(fg)) {
164         return nullptr;
165       }
166       // Check if parameters' changed for the first met branch calling.
167       if (fg->has_flag(FUNC_GRAPH_FLAG_AFTER_BLOCK)) {
168         auto param_simplified_caller = SimplifyAfterParameter(fg, node, args);
169         if (param_simplified_caller != nullptr) {
170           return param_simplified_caller;
171         }
172       }
173     }
174     // Or, just make a clone for not single used fg.
175     MS_LOG(INFO) << "Run InlineClone in inline pass, subgraph number may increase.";
176     return InlineClone(fg, node->func_graph(), args, inputs[0]->scope());
177   }
178 
InlineMove(const AnfNodePtr & node,const FuncGraphPtr & fg,const std::vector<AnfNodePtr> & args,const std::vector<AnfNodePtr> & inputs)179   AnfNodePtr InlineMove(const AnfNodePtr &node, const FuncGraphPtr &fg, const std::vector<AnfNodePtr> &args,
180                         const std::vector<AnfNodePtr> &inputs) {
181     auto mng = fg->manager();
182     MS_EXCEPTION_IF_NULL(mng);
183     ReplaceParams(mng, args, fg);
184     auto out_node = fg->output();
185     mng->MoveAllCNodeDropGraph(fg, node->func_graph(), inputs[0]->scope());
186     return out_node;
187   }
188 
InlineForUniqueUse(const AnfNodePtr & node,const FuncGraphPtr & fg,const std::vector<AnfNodePtr> & args,const std::vector<AnfNodePtr> & inputs)189   AnfNodePtr InlineForUniqueUse(const AnfNodePtr &node, const FuncGraphPtr &fg, const std::vector<AnfNodePtr> &args,
190                                 const std::vector<AnfNodePtr> &inputs) {
191     if (use_move_) {
192       return InlineMove(node, fg, args, inputs);
193     }
194 
195     // The other branch calling the last after block.
196     if (fg->has_flag(FUNC_GRAPH_FLAG_AFTER_BLOCK)) {
197       // Check if parameters' changed.
198       auto param_simplified_caller = SimplifyAfterParameter(fg, node, args);
199       if (param_simplified_caller != nullptr) {
200         return param_simplified_caller;
201       }
202     }
203     return nullptr;
204   }
205 
ApplyCriterions(const AnfNodePtr & node,const FuncGraphPtr & fg)206   bool ApplyCriterions(const AnfNodePtr &node, const FuncGraphPtr &fg) {
207     bool is_match = false;
208     for (auto &criterions : criterions_) {  // Each 'criterion group' in criterions_.
209       is_match = true;
210       for (auto &criterion : criterions) {  // Each criterion in 'criterion group'.
211         if (!criterion(this, fg, node)) {
212           is_match = false;
213           break;
214         }
215       }
216       if (is_match) {
217         break;
218       }
219     }
220     return is_match;
221   }
222 
ReplaceParams(const FuncGraphManagerPtr & mng,const std::vector<AnfNodePtr> & new_params,const FuncGraphPtr & fg)223   void ReplaceParams(const FuncGraphManagerPtr &mng, const std::vector<AnfNodePtr> &new_params,
224                      const FuncGraphPtr &fg) {
225     auto params = fg->parameters();
226     auto old_size = params.size();
227     if (old_size != new_params.size()) {
228       MS_LOG(EXCEPTION) << "Parameter size not match." << old_size << " new " << new_params.size()
229                         << fg->output()->DebugString(10);
230     }
231     for (size_t i = 0; i < old_size; i++) {
232       (void)mng->Replace(params[i], new_params[i]);
233     }
234   }
235 
IsRecursive(const FuncGraphPtr & fg)236   bool IsRecursive(const FuncGraphPtr &fg) {
237     if (!is_checked_) {
238       is_checked_ = true;
239       is_recursive_ = fg->recursive();
240     }
241     return is_recursive_;
242   }
243 
Reset()244   void Reset() {
245     is_checked_ = false;
246     is_recursive_ = false;
247   }
248 
249   // For after block which contains branch call, delete the parameters which is not used.
250   // In most cases, it may be a `Module` or other constant input.
SimplifyAfterParameter(const FuncGraphPtr & fg,const AnfNodePtr & node,const std::vector<AnfNodePtr> & args)251   AnfNodePtr SimplifyAfterParameter(const FuncGraphPtr &fg, const AnfNodePtr &node,
252                                     const std::vector<AnfNodePtr> &args) {
253     auto &fg_params = fg->parameters();
254     std::vector<int64_t> used_param_index;
255     auto mng = fg->manager();
256     MS_EXCEPTION_IF_NULL(mng);
257     bool should_simplify = false;
258     for (size_t i = 0; i < fg_params.size(); i++) {
259       if (mng->node_users()[fg_params[i]].size() != 0) {
260         used_param_index.emplace_back(i);
261       } else {
262         MS_LOG(DEBUG) << "Not used parameter " << fg_params[i]->DebugString() << " for calling " << fg->ToString();
263         should_simplify = true;
264       }
265     }
266     if (!should_simplify) {
267       return nullptr;
268     }
269     MS_LOG(DEBUG) << "Parameter not used found for graph :" << fg->ToString();
270     // Clone a new graph and ignore the not used parameters
271     auto new_fg = TransformableClone(fg);
272     auto &new_fg_params = new_fg->parameters();
273     std::vector<AnfNodePtr> new_params;
274     std::transform(used_param_index.begin(), used_param_index.end(), std::back_inserter(new_params),
275                    [&new_fg_params](size_t i) { return new_fg_params[i]; });
276     new_fg->set_parameters(new_params);
277 
278     std::vector<AnfNodePtr> node_inputs;
279     node_inputs.push_back(NewValueNode(new_fg));
280     std::transform(used_param_index.begin(), used_param_index.end(), std::back_inserter(node_inputs),
281                    [&args](size_t i) { return args[i]; });
282     return node->func_graph()->NewCNode(node_inputs);
283   }
284 
CheckSwitchBranchAbstract(const AbstractBasePtr & branch_abstract)285   bool CheckSwitchBranchAbstract(const AbstractBasePtr &branch_abstract) {
286     if (branch_abstract != nullptr && branch_abstract->isa<abstract::AbstractError>()) {
287       auto branch_abstract_value = branch_abstract->GetValueTrack();
288       MS_EXCEPTION_IF_NULL(branch_abstract_value);
289       auto branch_abstract_value_string_imm = branch_abstract_value->cast<StringImmPtr>();
290       if (branch_abstract_value_string_imm != nullptr) {
291         auto branch_abstract_value_string_imm_value = branch_abstract_value_string_imm->value();
292         return branch_abstract_value_string_imm_value == kDeadNodeName ||
293                branch_abstract_value_string_imm_value == kPolyNodeName;
294       }
295     }
296     return false;
297   }
298 
CheckSwitchInputs(const std::vector<AnfNodePtr> & sw_inputs)299   bool CheckSwitchInputs(const std::vector<AnfNodePtr> &sw_inputs) {
300     auto true_branch_abstract = sw_inputs[kSwitchTrueKernelGraphIndex]->abstract();
301     auto false_branch_abstract = sw_inputs[kSwitchFalseKernelGraphIndex]->abstract();
302     // When branch has dead node or poly node, do not perform inline.
303     if (CheckSwitchBranchAbstract(true_branch_abstract) || CheckSwitchBranchAbstract(false_branch_abstract)) {
304       return true;
305     }
306     return !sw_inputs[1]->isa<ValueNode>() || IsValueNode<tensor::Tensor>(sw_inputs[1]);
307   }
308 
309   // This is a try-best algorithm to find a graph which may generate branch call.
310   // It does not handle high-order function call. For high-orderer call branch, it still may be inlined.
GraphHasBranch(FuncGraphPtr fg)311   bool GraphHasBranch(FuncGraphPtr fg) {
312     if (graph_branch_cache_.find(fg) != graph_branch_cache_.end()) {
313       return graph_branch_cache_[fg];
314     }
315     bool has_branch = false;
316     auto nodes = fg->nodes();
317     for (auto &item : nodes) {
318       if (IsPrimitiveCNode(item, prim::kPrimSwitch)) {
319         auto sw_inputs = item->cast<CNodePtr>()->inputs();
320         if (sw_inputs.size() != 4) {
321           MS_LOG(EXCEPTION) << "switch inputs should be 4";
322         }
323         if (CheckSwitchInputs(sw_inputs)) {
324           has_branch = true;
325           break;
326         }
327       } else if (IsCNodeGraph(item)) {
328         auto cinputs = item->cast<CNodePtr>()->inputs();
329         if (cinputs.size() < 1) {
330           MS_LOG(EXCEPTION) << "graph call inputs should greater than 1";
331         }
332         FuncGraphPtr call_fg = GetValueNode<FuncGraphPtr>(cinputs[0]);
333         bool call_fg_has_branch = GraphHasBranch(call_fg);
334         if (call_fg_has_branch) {
335           has_branch = true;
336           break;
337         }
338       } else if (IsPrimitiveCNode(item, prim::kPrimPartial)) {
339         auto cinputs = item->cast<CNodePtr>()->inputs();
340         if (cinputs.size() < 2) {
341           MS_LOG(EXCEPTION) << "partial call inputs should greater than 2";
342         }
343         FuncGraphPtr call_fg = GetValueNode<FuncGraphPtr>(cinputs[1]);
344         if (call_fg == nullptr) {
345           continue;
346         }
347         bool call_fg_has_branch = GraphHasBranch(call_fg);
348         if (call_fg_has_branch) {
349           has_branch = true;
350           break;
351         }
352       }
353     }
354     graph_branch_cache_[fg] = has_branch;
355     return has_branch;
356   }
357 
358  private:
359   bool is_checked_{false}, is_recursive_{false};
360   bool use_move_;
361   std::vector<std::vector<CriterionFuncType>> criterions_;
362   std::unordered_map<FuncGraphPtr, bool> graph_branch_cache_;
363 };
364 
IsUniqueUse(InlinerBase *,const FuncGraphPtr & fg,const AnfNodePtr &)365 bool IsUniqueUse(InlinerBase *, const FuncGraphPtr &fg, const AnfNodePtr &) {
366   const auto &users = fg->func_graph_cnodes_index();
367   int64_t n_use = std::accumulate(
368     users.begin(), users.end(), 0,
369     [](int64_t sum, const std::pair<const CNodeIndexPairPtr, int64_t> &item) { return sum + item.second; });
370   return n_use == 1;
371 }
372 
IsTrivial(InlinerBase *,const FuncGraphPtr & fg,const AnfNodePtr &)373 bool IsTrivial(InlinerBase *, const FuncGraphPtr &fg, const AnfNodePtr &) {
374   auto n_cnode = fg->nodes().size() - fg->parameters().size();
375   // There is at least one CNode(return, other_node).
376   return n_cnode <= 2;
377 }
378 
IsInside(InlinerBase *,const FuncGraphPtr &,const AnfNodePtr & node)379 bool IsInside(InlinerBase *, const FuncGraphPtr &, const AnfNodePtr &node) {
380   MS_EXCEPTION_IF_NULL(node->func_graph());
381   return node->func_graph()->has_flag("inline_inside");
382 }
383 
IsCore(InlinerBase *,const FuncGraphPtr & fg,const AnfNodePtr &)384 bool IsCore(InlinerBase *, const FuncGraphPtr &fg, const AnfNodePtr &) { return fg->has_flag("core"); }
385 
IsDirectParentCall(InlinerBase *,const FuncGraphPtr & fg,const AnfNodePtr & node)386 bool IsDirectParentCall(InlinerBase *, const FuncGraphPtr &fg, const AnfNodePtr &node) {
387   bool unique_use = IsUniqueUse(nullptr, fg, nullptr);
388   bool is_recursive = fg->recursive();
389   if (fg->parent() != nullptr && is_recursive) {
390     if (fg->parent() == node->func_graph() && unique_use) {
391       return true;
392     }
393   }
394   return false;
395 }
396 
IsNotRecursive(InlinerBase * inliner,const FuncGraphPtr & fg,const AnfNodePtr &)397 bool IsNotRecursive(InlinerBase *inliner, const FuncGraphPtr &fg, const AnfNodePtr &) {
398   return !inliner->IsRecursive(fg);
399 }
400 
401 class Inliner : public InlinerBase {
402  public:
403   explicit Inliner(bool use_move = true)
404       : InlinerBase(
405           // Supports AND conditions in one criterion, Ex. {IsUniqueUse, IsNotRecursive}.
406           {
407             {IsTrivial},
408             {IsInside},
409             {IsCore},
410             {IsNotRecursive},
411             {IsDirectParentCall},
412           },
413           use_move) {}
414 
415   ~Inliner() override = default;
416 };
417 
418 class DirectInliner : public InlinerBase {
419  public:
420   explicit DirectInliner(bool use_move = true)
421       : InlinerBase(
422           // Supports AND conditions in one criterion, Ex. {IsUniqueUse, IsNotRecursive}.
423           {
424             {IsForceInline},
425             {IsDirectParentCall},
426           },
427           use_move) {}
428   ~DirectInliner() override = default;
429 };
430 }  // namespace irpass
431 }  // namespace opt
432 }  // namespace mindspore
433 #endif  // MINDSPORE_CCSRC_FRONTEND_OPTIMIZER_IRPASS_INLINE_H_
434