1 /**
2 * Copyright 2020-2024 Huawei Technologies Co., Ltd
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 #ifndef MINDSPORE_CCSRC_FRONTEND_OPTIMIZER_IRPASS_INLINE_H_
18 #define MINDSPORE_CCSRC_FRONTEND_OPTIMIZER_IRPASS_INLINE_H_
19
#include <algorithm>
#include <functional>
#include <memory>
#include <numeric>
#include <utility>
#include <vector>

#include "utils/hash_map.h"
#include "mindspore/core/ops/framework_ops.h"
#include "frontend/optimizer/irpass.h"
#include "include/common/utils/parallel_context.h"
#include "frontend/optimizer/optimizer.h"
#include "frontend/optimizer/anf_visitor.h"
#include "ir/func_graph.h"
#include "ir/func_graph_cloner.h"
#include "ir/tensor.h"
#include "frontend/operator/ops.h"
#include "abstract/abstract_value.h"
#include "include/common/utils/utils.h"
#include "pipeline/jit/ps/pipeline.h"
37
38 namespace mindspore {
39 namespace opt {
40 namespace irpass {
41 class ReplaceApplicator : public AnfVisitor {
42 public:
operator()43 AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override {
44 if (!IsValueNode<FuncGraph>(node)) {
45 return nullptr;
46 }
47 auto fg = GetValueNode<FuncGraphPtr>(node);
48 if (NoInline(fg)) {
49 return nullptr;
50 }
51
52 auto out = fg->output();
53 MS_EXCEPTION_IF_NULL(out);
54 if (!out->isa<CNode>()) {
55 return nullptr;
56 }
57
58 auto &inputs = out->cast<CNodePtr>()->inputs();
59 auto params = fg->parameters();
60
61 // Exclude first elements of inputs which is fn.
62 auto input_size = inputs.size();
63 auto param_size = params.size();
64 if ((input_size == 1 && param_size == 0) || (input_size > 1 && (input_size - 1) == param_size &&
65 std::equal(inputs.begin() + 1, inputs.end(), params.begin()))) {
66 auto inner = inputs[0];
67 if (IsValueNode<Primitive>(inner)) {
68 return inner;
69 }
70 if (IsValueNode<FuncGraph>(inner) && GetValueNode<FuncGraphPtr>(inner)->parent() == nullptr) {
71 const auto &inner_fg = GetValueNode<FuncGraphPtr>(inner);
72 MS_EXCEPTION_IF_NULL(inner_fg);
73 bool is_recursive = (inner_fg->has_flag(FUNC_GRAPH_FLAG_NO_RECURSIVE) ? false : inner_fg->recursive());
74 if (is_recursive) {
75 return nullptr;
76 }
77 return inner;
78 }
79 }
80
81 return nullptr;
82 }
83
NoInline(const FuncGraphPtr & fg)84 bool NoInline(const FuncGraphPtr &fg) const {
85 if (fg == nullptr || fg->has_flag(FUNC_GRAPH_FLAG_NO_INLINE) || fg->has_flag(FUNC_GRAPH_FLAG_DEFER_INLINE) ||
86 fg->stub() || *(fg->indirect())) {
87 return true;
88 }
89 // Defer inlining in the case of pipeline.
90 auto stage_num = parallel::ParallelContext::GetInstance()->pipeline_stage_split_num();
91 if (fg->stage() != -1 && stage_num > 1) {
92 return true;
93 }
94 // Defer inlining for:
95 // 1. The func_graph which is set recomputed.
96 // 2. The k graph whose primal is set non-recomputed when enable graph reuse.
97 auto context = MsContext::GetInstance();
98 MS_EXCEPTION_IF_NULL(context);
99 const auto cell_reuse = context->CellReuseLevel() != CellReuseLevel::kNoCellReuse;
100 return fg->has_flag(FUNC_GRAPH_OUTPUT_NO_RECOMPUTE) ||
101 (cell_reuse &&
102 (fg->has_flag(FUNC_GRAPH_NOT_RECOMPUTE_K_GRAPH) || fg->has_flag(FUNC_GRAPH_RECOMPUTE_K_GRAPH)));
103 }
104 };
105
106 class InlinerBase;
107 using CriterionFuncType = std::function<bool(InlinerBase *, const FuncGraphPtr &, const AnfNodePtr &)>;
108
109 bool IsUniqueUse(InlinerBase *, const FuncGraphPtr &fg, const AnfNodePtr &);
110
111 bool IsTrivial(InlinerBase *, const FuncGraphPtr &fg, const AnfNodePtr &);
112 bool IsInside(InlinerBase *, const FuncGraphPtr &, const AnfNodePtr &node);
113 bool IsCore(InlinerBase *, const FuncGraphPtr &fg, const AnfNodePtr &);
114 bool IsDirectParentCall(InlinerBase *, const FuncGraphPtr &fg, const AnfNodePtr &node);
115 bool IsNotRecursive(InlinerBase *inliner, const FuncGraphPtr &fg, const AnfNodePtr &);
IsForceInline(InlinerBase *,const FuncGraphPtr & fg,const AnfNodePtr &)116 bool IsForceInline(InlinerBase *, const FuncGraphPtr &fg, const AnfNodePtr &) {
117 return fg->has_flag(FUNC_GRAPH_FLAG_FORCE_INLINE);
118 }
119
120 // {G, Xs}
// {G, Xs}
// Base visitor for inlining a call {G, Xs} where G is a ValueNode<FuncGraph>:
// when one of the configured criterion groups matches, G's body is either
// moved into the caller (unique use) or cloned into it.
class InlinerBase : public AnfVisitor {
 public:
  // 'criterions' is a list of AND-groups OR-ed together (see ApplyCriterions).
  // 'use_move' selects move-inlining (vs. clone) for uniquely used graphs.
  explicit InlinerBase(const std::vector<std::vector<CriterionFuncType>> &criterions, bool use_move = true)
      : use_move_(use_move), criterions_(criterions) {}
  ~InlinerBase() override = default;
  // Returns the node replacing the call, or nullptr when no inlining is done.
  AnfNodePtr operator()(const OptimizerPtr &optimizer, const AnfNodePtr &node) override {
    auto cnode = dyn_cast<CNode>(node);
    if (cnode == nullptr || cnode->size() < 1) {
      return nullptr;
    }

    // Check if no recursive flag was set in top graph.
    CheckNoRecursive(optimizer);

    auto &inputs = cnode->inputs();
    // G
    auto fg = GetValueNode<FuncGraphPtr>(inputs[0]);
    if (!CheckFlag(fg)) {
      return nullptr;
    }

    Reset();

    // 'criterions_': {criterion_group_1:{criterion1, criterion2, ...}, criterion_group_2:{...}, ...}
    // All the criterions of 'criterion group' are true would set 'criterion group' as 'true'. As [AND].
    // Anyone of 'criterion group' in 'criterions_' is 'true' would be matched. As [OR].
    bool is_match = ApplyCriterions(node, fg);
    if (!is_match) {
      return nullptr;
    }

    std::vector<AnfNodePtr> args;
    (void)std::copy(inputs.begin() + 1, inputs.end(), std::back_inserter(args));
    // Compare size to avoid the case that the function has default value after grad.
    // for which after renormalize, the function default value will be an input
    if (fg->parameters().size() != args.size()) {
      return nullptr;
    }

    // Force-inline always succeeds: move when uniquely used, clone otherwise.
    if (IsForceInline(this, fg, node)) {
      if (IsUniqueUse(nullptr, fg, nullptr)) {
        return InlineMove(node, fg, args, inputs);
      }
      return InlineClone(fg, node->func_graph(), args, cnode);
    }

    if (IsUniqueUse(nullptr, fg, nullptr)) {
      // For the single used fg, including non-after and after not matched above,
      // we move the whole fg nodes.
      auto res_node = InlineForUniqueUse(node, fg, args, inputs);
      if (res_node != nullptr) {
        return res_node;
      }
    } else {
      // We don't expand the middle multiple used after block, except the last one.
      if (GraphHasBranch(fg)) {
        return nullptr;
      }
      // Check if parameters' changed for the first met branch calling.
      if (fg->has_flag(FUNC_GRAPH_FLAG_AFTER_BLOCK)) {
        auto param_simplified_caller = SimplifyAfterParameter(fg, node, args);
        if (param_simplified_caller != nullptr) {
          return param_simplified_caller;
        }
      }
    }
    // Or, just make a clone for not single used fg.
    auto res = InlineClone(fg, node->func_graph(), args, cnode);
    return res;
  }

  // True when `fg` is a candidate for inlining at all: rejects null/stub
  // graphs, explicit no-inline/defer flags, pipeline-staged graphs, and
  // recomputed/k graphs (mirrors ReplaceApplicator::NoInline, inverted).
  bool CheckFlag(const FuncGraphPtr &fg) const {
    if (fg == nullptr || fg->has_flag(FUNC_GRAPH_FLAG_NO_INLINE) || fg->has_flag(FUNC_GRAPH_FLAG_DEFER_INLINE) ||
        fg->stub()) {
      return false;
    }
    // Defer inlining in the case of pipeline.
    auto stage_num = parallel::ParallelContext::GetInstance()->pipeline_stage_split_num();
    if (fg->stage() != -1 && stage_num > 1) {
      return false;
    }
    // Defer inlining for:
    // 1. The func_graph which is set recomputed.
    // 2. The k graph whose primal is set non-recomputed when enable graph reuse.
    auto context = MsContext::GetInstance();
    MS_EXCEPTION_IF_NULL(context);
    const auto cell_reuse = context->CellReuseLevel() != CellReuseLevel::kNoCellReuse;
    if (fg->has_flag(FUNC_GRAPH_OUTPUT_NO_RECOMPUTE) ||
        (cell_reuse &&
         (fg->has_flag(FUNC_GRAPH_NOT_RECOMPUTE_K_GRAPH) || fg->has_flag(FUNC_GRAPH_RECOMPUTE_K_GRAPH)))) {
      return false;
    }
    return true;
  }

  // Whether `fg` is recursive. The answer is computed once per visited node
  // (cached in is_recursive_ until Reset()) and short-circuits to false when
  // the top graph guarantees no recursion.
  bool IsRecursive(const FuncGraphPtr &fg) {
    // The user guarantees that fg has no recursive.
    if (no_recursive_) {
      return false;
    }

    if (!is_checked_) {
      is_checked_ = true;
      is_recursive_ = (fg->has_flag(FUNC_GRAPH_FLAG_NO_RECURSIVE) ? false : fg->recursive());
    }
    return is_recursive_;
  }

  // True when the top graph carries FUNC_GRAPH_FLAG_NO_RECURSIVE.
  bool no_recursive() const { return no_recursive_; }

 private:
  // Move-inline: rebind fg's parameters to `args`, then move all of fg's
  // CNodes into the caller graph (fg is dropped). Returns fg's old output,
  // captured before the move, as the replacement node.
  AnfNodePtr InlineMove(const AnfNodePtr &node, const FuncGraphPtr &fg, const std::vector<AnfNodePtr> &args,
                        const std::vector<AnfNodePtr> &inputs) const {
    auto mng = fg->manager();
    MS_EXCEPTION_IF_NULL(mng);
    ReplaceParams(mng, args, fg);
    auto out_node = fg->output();
    mng->MoveAllCNodeDropGraph(fg, node->func_graph(), node, inputs[0]->scope(), true);
    return out_node;
  }

  // Inlining strategy for a graph with exactly one call site: move when
  // configured, otherwise only try parameter simplification for after-blocks.
  // Returns nullptr when the caller should fall back to cloning.
  AnfNodePtr InlineForUniqueUse(const AnfNodePtr &node, const FuncGraphPtr &fg, const std::vector<AnfNodePtr> &args,
                                const std::vector<AnfNodePtr> &inputs) const {
    if (use_move_) {
      return InlineMove(node, fg, args, inputs);
    }

    // The other branch calling the last after block.
    if (fg->has_flag(FUNC_GRAPH_FLAG_AFTER_BLOCK)) {
      // Check if parameters' changed.
      auto param_simplified_caller = SimplifyAfterParameter(fg, node, args);
      if (param_simplified_caller != nullptr) {
        return param_simplified_caller;
      }
    }
    return nullptr;
  }

  // OR over criterion groups; AND inside each group. Returns true as soon as
  // one whole group passes.
  bool ApplyCriterions(const AnfNodePtr &node, const FuncGraphPtr &fg) {
    bool is_match = false;
    for (auto &criterions : criterions_) {  // Each 'criterion group' in criterions_.
      is_match = true;
      for (auto &criterion : criterions) {  // Each criterion in 'criterion group'.
        if (!criterion(this, fg, node)) {
          is_match = false;
          break;
        }
      }
      if (is_match) {
        break;
      }
    }
    return is_match;
  }

  // Replace each of fg's parameters by the corresponding call argument via the
  // manager. Sizes must match exactly (internal error otherwise).
  void ReplaceParams(const FuncGraphManagerPtr &mng, const std::vector<AnfNodePtr> &new_params,
                     const FuncGraphPtr &fg) const {
    auto params = fg->parameters();
    auto old_size = params.size();
    constexpr auto print_deep = 10;
    if (old_size != new_params.size()) {
      MS_LOG(INTERNAL_EXCEPTION) << "Parameter size not match." << old_size << " new " << new_params.size()
                                 << fg->output()->DebugString(print_deep);
    }
    for (size_t i = 0; i < old_size; i++) {
      (void)mng->Replace(params[i], new_params[i]);
    }
  }

  // Latch no_recursive_ from the top graph's FUNC_GRAPH_FLAG_NO_RECURSIVE
  // flag, when a pipeline resource is available on the optimizer.
  // NOTE(review): no_recursive_ is never reset once set — assumed intentional
  // for the lifetime of this pass instance.
  void CheckNoRecursive(const OptimizerPtr &optimizer) {
    // Check if no recursive flag was set in top graph.
    const auto &resource = std::dynamic_pointer_cast<pipeline::Resource>(optimizer->resource());
    if (resource == nullptr) {
      return;
    }
    const auto &top_graph = resource->func_graph();
    if (top_graph == nullptr) {
      return;
    }
    if (top_graph->has_flag(FUNC_GRAPH_FLAG_NO_RECURSIVE)) {
      no_recursive_ = true;
    }
  }

  // Clear the per-node recursion cache (no_recursive_ is kept).
  void Reset() {
    is_checked_ = false;
    is_recursive_ = false;
  }

  // For after block which contains branch call, delete the parameters which is not used.
  // In most cases, it may be a `Module` or other constant input.
  // Returns a new call CNode on a cloned graph keeping only used parameters,
  // or nullptr when every parameter is used.
  AnfNodePtr SimplifyAfterParameter(const FuncGraphPtr &fg, const AnfNodePtr &node,
                                    const std::vector<AnfNodePtr> &args) const {
    auto &fg_params = fg->parameters();
    std::vector<int64_t> used_param_index;
    auto mng = fg->manager();
    MS_EXCEPTION_IF_NULL(mng);
    bool should_simplify = false;
    for (size_t i = 0; i < fg_params.size(); i++) {
      if (mng->node_users()[fg_params[i]].size() != 0) {
        (void)used_param_index.emplace_back(i);
      } else {
        MS_LOG(DEBUG) << "Not used parameter " << fg_params[i]->DebugString() << " for calling " << fg->ToString();
        should_simplify = true;
      }
    }
    if (!should_simplify) {
      return nullptr;
    }
    MS_LOG(DEBUG) << "Parameter not used found for graph :" << fg->ToString();
    // Clone a new graph and ignore the not used parameters
    auto new_fg = TransformableClone(fg);
    auto &new_fg_params = new_fg->parameters();
    std::vector<AnfNodePtr> new_params;
    std::transform(used_param_index.begin(), used_param_index.end(), std::back_inserter(new_params),
                   [&new_fg_params](size_t i) { return new_fg_params[i]; });
    new_fg->set_parameters(new_params);

    std::vector<AnfNodePtr> node_inputs;
    node_inputs.push_back(NewValueNode(new_fg));
    std::transform(used_param_index.begin(), used_param_index.end(), std::back_inserter(node_inputs),
                   [&args](size_t i) { return args[i]; });
    auto ret_node = node->func_graph()->NewCNode(node_inputs);
    ret_node->set_abstract(node->abstract());
    returnapenas ret_node;
  }

  // True when this switch should be treated as a real branch: a dead/poly
  // branch node, a non-constant condition, or a tensor-valued condition.
  bool CheckSwitchInputs(const std::vector<AnfNodePtr> &sw_inputs) const {
    // When branch has dead node or poly node, do not perform inline.
    if (IsDeadNode(sw_inputs[kSwitchTrueBranchIndex]) || IsPolyNode(sw_inputs[kSwitchTrueBranchIndex]) ||
        IsDeadNode(sw_inputs[kSwitchFalseBranchIndex]) || IsPolyNode(sw_inputs[kSwitchFalseBranchIndex])) {
      return true;
    }
    return !sw_inputs[1]->isa<ValueNode>() || IsValueNode<tensor::Tensor>(sw_inputs[1]);
  }

  // This is a try-best algorithm to find a graph which may generate branch call.
  // It does not handle high-order function call. For high-orderer call branch, it still may be inlined.
  // Results are memoized in graph_branch_cache_; recursion covers direct graph
  // calls and Partial wrappers.
  bool GraphHasBranch(const FuncGraphPtr &fg) {
    if (graph_branch_cache_.find(fg) != graph_branch_cache_.end()) {
      return graph_branch_cache_[fg];
    }
    bool has_branch = false;
    auto nodes = fg->nodes();
    for (auto &item : nodes) {
      if (IsPrimitiveCNode(item, prim::kPrimSwitch)) {
        auto sw_inputs = item->cast<CNodePtr>()->inputs();
        if (sw_inputs.size() != kIndex4) {
          MS_LOG(EXCEPTION) << "Switch inputs should be 4";
        }
        if (CheckSwitchInputs(sw_inputs)) {
          has_branch = true;
          break;
        }
      } else if (IsCNodeGraph(item)) {
        auto cinputs = item->cast<CNodePtr>()->inputs();
        if (cinputs.size() < 1) {
          MS_LOG(EXCEPTION) << "Graph call inputs should be greater than 1";
        }
        FuncGraphPtr call_fg = GetValueNode<FuncGraphPtr>(cinputs[0]);
        bool call_fg_has_branch = GraphHasBranch(call_fg);
        if (call_fg_has_branch) {
          has_branch = true;
          break;
        }
      } else if (IsPrimitiveCNode(item, prim::kPrimPartial)) {
        auto cinputs = item->cast<CNodePtr>()->inputs();
        if (cinputs.size() < kIndex2) {
          MS_LOG(EXCEPTION) << "Partial call inputs should be greater than 2";
        }
        FuncGraphPtr call_fg = GetValueNode<FuncGraphPtr>(cinputs[1]);
        if (call_fg == nullptr) {
          continue;
        }
        bool call_fg_has_branch = GraphHasBranch(call_fg);
        if (call_fg_has_branch) {
          has_branch = true;
          break;
        }
      }
    }
    graph_branch_cache_[fg] = has_branch;
    return has_branch;
  }

  bool is_checked_{false};     // Recursion check done for the current node.
  bool is_recursive_{false};   // Cached result of the recursion check.
  // If the user guarantee that fg has no recursive.
  bool no_recursive_{false};
  bool use_move_;              // Move (true) vs. clone for uniquely used graphs.
  std::vector<std::vector<CriterionFuncType>> criterions_;
  mindspore::HashMap<FuncGraphPtr, bool> graph_branch_cache_;
};
414
IsUniqueUse(InlinerBase *,const FuncGraphPtr & fg,const AnfNodePtr &)415 bool IsUniqueUse(InlinerBase *, const FuncGraphPtr &fg, const AnfNodePtr &) {
416 const auto &users = fg->func_graph_cnodes_index();
417 int64_t n_use = std::accumulate(
418 users.begin(), users.end(), 0,
419 [](int64_t sum, const std::pair<const CNodeIndexPairPtr, int64_t> &item) { return sum + item.second; });
420 return n_use == 1;
421 }
422
IsTrivial(InlinerBase *,const FuncGraphPtr & fg,const AnfNodePtr &)423 bool IsTrivial(InlinerBase *, const FuncGraphPtr &fg, const AnfNodePtr &) {
424 auto n_cnode = fg->nodes().size() - fg->parameters().size();
425 // There is at least one CNode(return, other_node).
426 constexpr size_t least_size = 2;
427 return n_cnode <= least_size;
428 }
429
IsInside(InlinerBase *,const FuncGraphPtr &,const AnfNodePtr & node)430 bool IsInside(InlinerBase *, const FuncGraphPtr &, const AnfNodePtr &node) {
431 MS_EXCEPTION_IF_NULL(node->func_graph());
432 return node->func_graph()->has_flag("inline_inside");
433 }
434
IsCore(InlinerBase *,const FuncGraphPtr & fg,const AnfNodePtr &)435 bool IsCore(InlinerBase *, const FuncGraphPtr &fg, const AnfNodePtr &) { return fg->has_flag("core"); }
436
IsDirectParentCall(InlinerBase * inliner,const FuncGraphPtr & fg,const AnfNodePtr & node)437 bool IsDirectParentCall(InlinerBase *inliner, const FuncGraphPtr &fg, const AnfNodePtr &node) {
438 bool is_recursive = (inliner->no_recursive() ? false : fg->recursive());
439 if (fg->parent() != nullptr && is_recursive) {
440 if (fg->parent() == node->func_graph() && IsUniqueUse(nullptr, fg, nullptr)) {
441 return true;
442 }
443 }
444 return false;
445 }
446
IsNotRecursive(InlinerBase * inliner,const FuncGraphPtr & fg,const AnfNodePtr &)447 bool IsNotRecursive(InlinerBase *inliner, const FuncGraphPtr &fg, const AnfNodePtr &) {
448 return !inliner->IsRecursive(fg);
449 }
450
451 class Inliner : public InlinerBase {
452 public:
453 explicit Inliner(bool use_move = true)
454 : InlinerBase(
455 // Supports AND conditions in one criterion, Ex. {IsUniqueUse, IsNotRecursive}.
456 {
457 {IsTrivial},
458 {IsInside},
459 {IsCore},
460 {IsNotRecursive},
461 {IsDirectParentCall},
462 },
463 use_move) {}
464
465 ~Inliner() override = default;
466 };
467
468 class DirectInliner : public InlinerBase {
469 public:
470 explicit DirectInliner(bool use_move = true)
471 : InlinerBase(
472 // Supports AND conditions in one criterion, Ex. {IsUniqueUse, IsNotRecursive}.
473 {
474 {IsForceInline},
475 {IsDirectParentCall},
476 },
477 use_move) {}
478 ~DirectInliner() override = default;
479 };
480 } // namespace irpass
481 } // namespace opt
482 } // namespace mindspore
483 #endif // MINDSPORE_CCSRC_FRONTEND_OPTIMIZER_IRPASS_INLINE_H_
484