/**
 * Copyright 2020-2021 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#ifndef MINDSPORE_CCSRC_FRONTEND_OPTIMIZER_IRPASS_INLINE_H_
#define MINDSPORE_CCSRC_FRONTEND_OPTIMIZER_IRPASS_INLINE_H_

#include <vector>
#include <utility>
#include <algorithm>
#include <functional>
#include <numeric>
#include <unordered_map>

#include "frontend/optimizer/irpass.h"
#include "frontend/parallel/context.h"
#include "frontend/optimizer/optimizer.h"
#include "frontend/optimizer/anf_visitor.h"
#include "ir/func_graph.h"
#include "ir/func_graph_cloner.h"
#include "ir/tensor.h"
#include "frontend/operator/ops.h"
#include "abstract/abstract_value.h"
#include "utils/utils.h"

namespace mindspore {
namespace opt {
namespace irpass {
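// Replace a reference to an applicator graph with the function it merely forwards to.
// An applicator is a graph G whose body simply passes its own parameters on,
// e.g. G(x, y) = {inner, x, y}; the reference to G can then be replaced by
// `inner` when `inner` is a Primitive or a FuncGraph with no parent.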
class ReplaceApplicator : public AnfVisitor {
 public:
  AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override {
    if (!IsValueNode<FuncGraph>(node)) {
      return nullptr;
    }

    auto fg = GetValueNode<FuncGraphPtr>(node);
    if (fg->has_flag(FUNC_GRAPH_FLAG_DEFER_INLINE) || fg->stub() || *(fg->switch_input()) ||
        *(fg->switch_layer_input())) {
      return nullptr;
    }
    // Defer inlining in the case of pipeline.
    auto stage_num = parallel::ParallelContext::GetInstance()->pipeline_stage_split_num();
    if (fg->stage() != -1 && stage_num > 1) {
      return nullptr;
    }
    // Defer inlining so we can still find the output nodes of a recomputed cell whose output is not recomputed.
    if (fg->has_flag(FUNC_GRAPH_OUTPUT_NO_RECOMPUTE)) {
      return nullptr;
    }

    auto out = fg->output();
    MS_EXCEPTION_IF_NULL(out);
    if (!out->isa<CNode>()) {
      return nullptr;
    }

    auto &inputs = out->cast<CNodePtr>()->inputs();
    auto params = fg->parameters();

    // Exclude the first element of inputs, which is fn.
    auto input_size = inputs.size();
    auto param_size = params.size();
    if ((input_size == 1 && param_size == 0) || (input_size > 1 && (input_size - 1) == param_size &&
                                                 std::equal(inputs.begin() + 1, inputs.end(), params.begin()))) {
      auto inner = inputs[0];
      if (IsValueNode<Primitive>(inner) ||
          (IsValueNode<FuncGraph>(inner) && GetValueNode<FuncGraphPtr>(inner)->parent() == nullptr)) {
        return inner;
      }
    }

    return nullptr;
  }
};

class InlinerBase;
using CriterionFuncType = std::function<bool(InlinerBase *, const FuncGraphPtr &, const AnfNodePtr &)>;

bool IsUniqueUse(InlinerBase *, const FuncGraphPtr &fg, const AnfNodePtr &);

bool IsTrivial(InlinerBase *, const FuncGraphPtr &fg, const AnfNodePtr &);
bool IsInside(InlinerBase *, const FuncGraphPtr &, const AnfNodePtr &node);
bool IsCore(InlinerBase *, const FuncGraphPtr &fg, const AnfNodePtr &);
bool IsDirectParentCall(InlinerBase *, const FuncGraphPtr &fg, const AnfNodePtr &node);
bool IsNotRecursive(InlinerBase *inliner, const FuncGraphPtr &fg, const AnfNodePtr &);
bool IsForceInline(InlinerBase *, const FuncGraphPtr &fg, const AnfNodePtr &) {
  return fg->has_flag(FUNC_GRAPH_FLAG_FORCE_INLINE);
}

// {G, Xs}
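// Base visitor for the inline passes: matches a call {G, Xs} and inlines the
// graph G into the caller when any one of the configured criterion groups is
// satisfied, either by moving G's nodes (single use) or by cloning G.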
class InlinerBase : public AnfVisitor {
 public:
  explicit InlinerBase(std::vector<std::vector<CriterionFuncType>> criterions, bool use_move = true)
      : use_move_(use_move), criterions_(criterions) {}
  ~InlinerBase() override = default;
  AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override {
    auto cnode = dyn_cast<CNode>(node);
    if (cnode == nullptr || cnode->size() < 1) {
      return nullptr;
    }

    auto &inputs = cnode->inputs();
    // G
    auto fg = GetValueNode<FuncGraphPtr>(inputs[0]);
    if (fg == nullptr || fg->has_flag(FUNC_GRAPH_FLAG_DEFER_INLINE) || fg->stub()) {
      return nullptr;
    }
    // Defer inlining in the case of pipeline.
    auto stage_num = parallel::ParallelContext::GetInstance()->pipeline_stage_split_num();
    if (fg->stage() != -1 && stage_num > 1) {
      return nullptr;
    }
    // Defer inlining so we can still find the output nodes of a recomputed cell whose output is not recomputed.
    if (fg->has_flag(FUNC_GRAPH_OUTPUT_NO_RECOMPUTE)) {
      return nullptr;
    }

    Reset();

    // 'criterions_': {criterion_group_1:{criterion1, criterion2, ...}, criterion_group_2:{...}, ...}
    // A criterion group matches only if all of its criteria are true [AND];
    // the node is matched if any criterion group in 'criterions_' matches [OR].
    bool is_match = ApplyCriterions(node, fg);
    if (!is_match) {
      return nullptr;
    }

    std::vector<AnfNodePtr> args;
    (void)std::copy(inputs.begin() + 1, inputs.end(), std::back_inserter(args));
    // Compare sizes to handle the case where the function has default values after grad;
    // after renormalize, the default values become extra inputs.
    if (fg->parameters().size() != args.size()) {
      return nullptr;
    }

    if (IsForceInline(this, fg, node)) {
      if (IsUniqueUse(nullptr, fg, nullptr)) {
        return InlineMove(node, fg, args, inputs);
      } else {
        return InlineClone(fg, node->func_graph(), args, inputs[0]->scope());
      }
    }

    if (IsUniqueUse(nullptr, fg, nullptr)) {
      // For a single-use fg (both non-after blocks and after blocks not matched above),
      // move the whole fg's nodes into the caller.
      auto ret_node = InlineForUniqueUse(node, fg, args, inputs);
      if (ret_node != nullptr) {
        return ret_node;
      }
    } else {
      // Don't expand an intermediate, multiply-used after block, except the last one.
      if (GraphHasBranch(fg)) {
        return nullptr;
      }
      // Check whether the parameters changed for the first encountered branch call.
      if (fg->has_flag(FUNC_GRAPH_FLAG_AFTER_BLOCK)) {
        auto param_simplified_caller = SimplifyAfterParameter(fg, node, args);
        if (param_simplified_caller != nullptr) {
          return param_simplified_caller;
        }
      }
    }
    // Otherwise, just make a clone of the fg, since it is not single-use.
    MS_LOG(INFO) << "Run InlineClone in inline pass; the number of subgraphs may increase.";
    return InlineClone(fg, node->func_graph(), args, inputs[0]->scope());
  }

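  // Inline by moving: replace fg's parameters with the call arguments, move all
  // of fg's CNodes into the caller graph, and return fg's former output node.
  // Only safe when fg has a single call site, since fg is consumed in the process.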
  AnfNodePtr InlineMove(const AnfNodePtr &node, const FuncGraphPtr &fg, const std::vector<AnfNodePtr> &args,
                        const std::vector<AnfNodePtr> &inputs) {
    auto mng = fg->manager();
    MS_EXCEPTION_IF_NULL(mng);
    ReplaceParams(mng, args, fg);
    auto out_node = fg->output();
    mng->MoveAllCNodeDropGraph(fg, node->func_graph(), inputs[0]->scope());
    return out_node;
  }

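  // Inline a single-use fg: move its nodes when 'use_move_' is set; otherwise,
  // for an after block, try to rebuild the call with unused parameters dropped.
  // Returns nullptr if neither strategy applies, so the caller falls back to cloning.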
  AnfNodePtr InlineForUniqueUse(const AnfNodePtr &node, const FuncGraphPtr &fg, const std::vector<AnfNodePtr> &args,
                                const std::vector<AnfNodePtr> &inputs) {
    if (use_move_) {
      return InlineMove(node, fg, args, inputs);
    }

    // The other branch calling the last after block.
    if (fg->has_flag(FUNC_GRAPH_FLAG_AFTER_BLOCK)) {
      // Check whether the parameters changed.
      auto param_simplified_caller = SimplifyAfterParameter(fg, node, args);
      if (param_simplified_caller != nullptr) {
        return param_simplified_caller;
      }
    }
    return nullptr;
  }

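  // Evaluate the criterion groups against the candidate call: every criterion
  // inside a group must hold (AND), and the first fully-matching group wins (OR).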
  bool ApplyCriterions(const AnfNodePtr &node, const FuncGraphPtr &fg) {
    bool is_match = false;
    for (auto &criterions : criterions_) {  // Each 'criterion group' in criterions_.
      is_match = true;
      for (auto &criterion : criterions) {  // Each criterion in a 'criterion group'.
        if (!criterion(this, fg, node)) {
          is_match = false;
          break;
        }
      }
      if (is_match) {
        break;
      }
    }
    return is_match;
  }

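  // Substitute each of fg's parameters with the corresponding caller argument
  // through the manager, so every user of a parameter now sees the argument.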
  void ReplaceParams(const FuncGraphManagerPtr &mng, const std::vector<AnfNodePtr> &new_params,
                     const FuncGraphPtr &fg) {
    auto params = fg->parameters();
    auto old_size = params.size();
    if (old_size != new_params.size()) {
      MS_LOG(EXCEPTION) << "Parameter size not match. Old size: " << old_size
                        << ", new size: " << new_params.size() << ". Node: " << fg->output()->DebugString(10);
    }
    for (size_t i = 0; i < old_size; i++) {
      (void)mng->Replace(params[i], new_params[i]);
    }
  }

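  // Check whether fg is recursive; the result is computed once per Reset() and
  // cached, since fg->recursive() can be costly to evaluate repeatedly.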
  bool IsRecursive(const FuncGraphPtr &fg) {
    if (!is_checked_) {
      is_checked_ = true;
      is_recursive_ = fg->recursive();
    }
    return is_recursive_;
  }

  void Reset() {
    is_checked_ = false;
    is_recursive_ = false;
  }

  // For an after block that contains a branch call, delete the parameters that are not used.
  // In most cases they may be a `Module` or other constant inputs.
  AnfNodePtr SimplifyAfterParameter(const FuncGraphPtr &fg, const AnfNodePtr &node,
                                    const std::vector<AnfNodePtr> &args) {
    auto &fg_params = fg->parameters();
    std::vector<size_t> used_param_index;
    auto mng = fg->manager();
    MS_EXCEPTION_IF_NULL(mng);
    bool should_simplify = false;
    for (size_t i = 0; i < fg_params.size(); i++) {
      if (!mng->node_users()[fg_params[i]].empty()) {
        used_param_index.emplace_back(i);
      } else {
        MS_LOG(DEBUG) << "Unused parameter " << fg_params[i]->DebugString() << " for calling " << fg->ToString();
        should_simplify = true;
      }
    }
    if (!should_simplify) {
      return nullptr;
    }
    MS_LOG(DEBUG) << "Unused parameter found for graph: " << fg->ToString();
    // Clone a new graph and drop the unused parameters.
    auto new_fg = TransformableClone(fg);
    auto &new_fg_params = new_fg->parameters();
    std::vector<AnfNodePtr> new_params;
    std::transform(used_param_index.begin(), used_param_index.end(), std::back_inserter(new_params),
                   [&new_fg_params](size_t i) { return new_fg_params[i]; });
    new_fg->set_parameters(new_params);

    std::vector<AnfNodePtr> node_inputs;
    node_inputs.push_back(NewValueNode(new_fg));
    std::transform(used_param_index.begin(), used_param_index.end(), std::back_inserter(node_inputs),
                   [&args](size_t i) { return args[i]; });
    return node->func_graph()->NewCNode(node_inputs);
  }

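  // A switch branch whose abstract is an AbstractError holding the dead-node or
  // poly-node sentinel string cannot be inlined; detect that case here.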
  bool CheckSwitchBranchAbstract(const AbstractBasePtr &branch_abstract) {
    if (branch_abstract != nullptr && branch_abstract->isa<abstract::AbstractError>()) {
      auto branch_abstract_value = branch_abstract->GetValueTrack();
      MS_EXCEPTION_IF_NULL(branch_abstract_value);
      auto branch_abstract_value_string_imm = branch_abstract_value->cast<StringImmPtr>();
      if (branch_abstract_value_string_imm != nullptr) {
        auto branch_abstract_value_string_imm_value = branch_abstract_value_string_imm->value();
        return branch_abstract_value_string_imm_value == kDeadNodeName ||
               branch_abstract_value_string_imm_value == kPolyNodeName;
      }
    }
    return false;
  }

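  // Returns true when a switch should be treated as a real branch: either one
  // of its branches is a dead/poly node, or its condition is not decided at
  // compile time (a non-ValueNode input, or a Tensor value known only at runtime).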
  bool CheckSwitchInputs(const std::vector<AnfNodePtr> &sw_inputs) {
    auto true_branch_abstract = sw_inputs[kSwitchTrueKernelGraphIndex]->abstract();
    auto false_branch_abstract = sw_inputs[kSwitchFalseKernelGraphIndex]->abstract();
    // When a branch has a dead node or poly node, do not perform inline.
    if (CheckSwitchBranchAbstract(true_branch_abstract) || CheckSwitchBranchAbstract(false_branch_abstract)) {
      return true;
    }
    return !sw_inputs[1]->isa<ValueNode>() || IsValueNode<tensor::Tensor>(sw_inputs[1]);
  }

  // This is a best-effort algorithm to find a graph which may generate a branch call.
  // It does not handle high-order function calls, so a branch behind a high-order call may still be inlined.
  bool GraphHasBranch(FuncGraphPtr fg) {
    if (graph_branch_cache_.find(fg) != graph_branch_cache_.end()) {
      return graph_branch_cache_[fg];
    }
    bool has_branch = false;
    auto nodes = fg->nodes();
    for (auto &item : nodes) {
      if (IsPrimitiveCNode(item, prim::kPrimSwitch)) {
        auto sw_inputs = item->cast<CNodePtr>()->inputs();
        if (sw_inputs.size() != 4) {
          MS_LOG(EXCEPTION) << "Switch node should have 4 inputs, but got " << sw_inputs.size();
        }
        if (CheckSwitchInputs(sw_inputs)) {
          has_branch = true;
          break;
        }
      } else if (IsCNodeGraph(item)) {
        auto cinputs = item->cast<CNodePtr>()->inputs();
        if (cinputs.size() < 1) {
          MS_LOG(EXCEPTION) << "Graph call should have at least 1 input, but got " << cinputs.size();
        }
        FuncGraphPtr call_fg = GetValueNode<FuncGraphPtr>(cinputs[0]);
        bool call_fg_has_branch = GraphHasBranch(call_fg);
        if (call_fg_has_branch) {
          has_branch = true;
          break;
        }
      } else if (IsPrimitiveCNode(item, prim::kPrimPartial)) {
        auto cinputs = item->cast<CNodePtr>()->inputs();
        if (cinputs.size() < 2) {
          MS_LOG(EXCEPTION) << "Partial call should have at least 2 inputs, but got " << cinputs.size();
        }
        FuncGraphPtr call_fg = GetValueNode<FuncGraphPtr>(cinputs[1]);
        if (call_fg == nullptr) {
          continue;
        }
        bool call_fg_has_branch = GraphHasBranch(call_fg);
        if (call_fg_has_branch) {
          has_branch = true;
          break;
        }
      }
    }
    graph_branch_cache_[fg] = has_branch;
    return has_branch;
  }

 private:
  bool is_checked_{false}, is_recursive_{false};
  bool use_move_;
  std::vector<std::vector<CriterionFuncType>> criterions_;
  std::unordered_map<FuncGraphPtr, bool> graph_branch_cache_;
};

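// A graph is uniquely used when the total number of its call sites, summed over
// all referencing CNodes, is exactly one.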
bool IsUniqueUse(InlinerBase *, const FuncGraphPtr &fg, const AnfNodePtr &) {
  const auto &users = fg->func_graph_cnodes_index();
  int64_t n_use = std::accumulate(
    users.begin(), users.end(), int64_t(0),
    [](int64_t sum, const std::pair<const CNodeIndexPairPtr, int64_t> &item) { return sum + item.second; });
  return n_use == 1;
}

bool IsTrivial(InlinerBase *, const FuncGraphPtr &fg, const AnfNodePtr &) {
  auto n_cnode = fg->nodes().size() - fg->parameters().size();
  // Trivial means at most two CNodes: the return node plus one other node.
  return n_cnode <= 2;
}

bool IsInside(InlinerBase *, const FuncGraphPtr &, const AnfNodePtr &node) {
  MS_EXCEPTION_IF_NULL(node->func_graph());
  return node->func_graph()->has_flag("inline_inside");
}

bool IsCore(InlinerBase *, const FuncGraphPtr &fg, const AnfNodePtr &) { return fg->has_flag("core"); }

bool IsDirectParentCall(InlinerBase *, const FuncGraphPtr &fg, const AnfNodePtr &node) {
  bool unique_use = IsUniqueUse(nullptr, fg, nullptr);
  bool is_recursive = fg->recursive();
  if (fg->parent() != nullptr && is_recursive) {
    if (fg->parent() == node->func_graph() && unique_use) {
      return true;
    }
  }
  return false;
}

bool IsNotRecursive(InlinerBase *inliner, const FuncGraphPtr &fg, const AnfNodePtr &) {
  return !inliner->IsRecursive(fg);
}

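// The general inliner: a graph is inlined if it is trivial, marked as inside or
// core, not recursive, or a uniquely-used direct call from its parent graph.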
class Inliner : public InlinerBase {
 public:
  explicit Inliner(bool use_move = true)
      : InlinerBase(
          // A criterion group ANDs its conditions, e.g. {IsUniqueUse, IsNotRecursive}.
          {
            {IsTrivial},
            {IsInside},
            {IsCore},
            {IsNotRecursive},
            {IsDirectParentCall},
          },
          use_move) {}

  ~Inliner() override = default;
};

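// A restricted inliner: only force-inlined graphs and uniquely-used direct
// parent calls are expanded.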
class DirectInliner : public InlinerBase {
 public:
  explicit DirectInliner(bool use_move = true)
      : InlinerBase(
          // A criterion group ANDs its conditions, e.g. {IsUniqueUse, IsNotRecursive}.
          {
            {IsForceInline},
            {IsDirectParentCall},
          },
          use_move) {}
  ~DirectInliner() override = default;
};
}  // namespace irpass
}  // namespace opt
}  // namespace mindspore
#endif  // MINDSPORE_CCSRC_FRONTEND_OPTIMIZER_IRPASS_INLINE_H_