1 /**
2 * Copyright 2021-2023 Huawei Technologies Co., Ltd
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 #ifndef MINDSPORE_CCSRC_FRONTEND_OPTIMIZER_IRPASS_PARAMETER_ELIMINATE_H
18 #define MINDSPORE_CCSRC_FRONTEND_OPTIMIZER_IRPASS_PARAMETER_ELIMINATE_H
19 #include <vector>
20 #include <utility>
21 #include <memory>
22
23 #include "utils/hash_set.h"
24 #include "mindspore/core/ops/sequence_ops.h"
25 #include "mindspore/core/ops/framework_ops.h"
26 #include "frontend/optimizer/irpass.h"
27 #include "frontend/optimizer/optimizer.h"
28 #include "frontend/optimizer/anf_visitor.h"
29 #include "ir/manager.h"
30 #include "ir/func_graph.h"
31 #include "frontend/operator/ops.h"
32
33 namespace mindspore {
34 namespace opt {
35 namespace irpass {
CheckSwitchCallValid(const CNodePtr & switch_call)36 static inline void CheckSwitchCallValid(const CNodePtr &switch_call) {
37 if (switch_call->size() > 1) {
38 // Means call switch(arg1, ...) has args.
39 constexpr auto recursive_count = 2;
40 MS_LOG(INTERNAL_EXCEPTION) << "After switch_call_monad_eliminater pass, the call switch node should not has args."
41 << " The call_switch_cnode is: " << switch_call->DebugString(recursive_count);
42 }
43 }
44
GetCallers(const FuncGraphPtr & fg)45 static inline std::vector<CNodePtr> GetCallers(const FuncGraphPtr &fg) {
46 MS_EXCEPTION_IF_NULL(fg);
47 const auto &fg_caller_and_indexes = fg->func_graph_cnodes_index();
48 std::vector<CNodePtr> caller_cnodes = {};
49 // Find all caller of fg.
50 auto manager = fg->manager();
51 MS_EXCEPTION_IF_NULL(manager);
52 auto &node_users = manager->node_users();
53 for (const auto &it : fg_caller_and_indexes) {
54 const auto &fg_caller_and_index = it.first;
55 auto caller_cnode = fg_caller_and_index->first;
56 auto index = fg_caller_and_index->second;
57 // If index != 0, the caller is a indirect caller, can't erase the parameter of graph.
58 // Because in this situation ValueNode<FuncGraph> is a input of Return or of MakeTuple.
59 MS_LOG(DEBUG) << "index: " << index;
60 // Process has partial func_graph with Primitive
61 // %1 = Partial(func_graph, arg1, arg2, ...)
62 if (index == 1 && IsPrimitiveCNode(caller_cnode, prim::kPrimPartial)) {
63 auto iter = node_users.find(caller_cnode);
64 for (auto &user : iter->second) {
65 auto &user_node = user.first;
66 auto user_cnode = user_node->cast<CNodePtr>();
67 // Check user of partial (switch), the numbers of args should be 0.
68 if (IsPrimitiveCNode(user_cnode, prim::kPrimSwitch)) {
69 // Call switch()
70 auto call_switchs = node_users[user_cnode];
71 for (auto call_switch_iter : call_switchs) {
72 CheckSwitchCallValid(call_switch_iter.first->cast<CNodePtr>());
73 }
74 if (std::find(caller_cnodes.begin(), caller_cnodes.end(), caller_cnode) == caller_cnodes.end()) {
75 (void)caller_cnodes.emplace_back(caller_cnode->cast<CNodePtr>());
76 }
77 }
78 }
79 } else if (index != 0) {
80 return {};
81 } else {
82 // Process call func_graph: %1 = func_graph(arg1, arg2, ...)
83 (void)caller_cnodes.emplace_back(caller_cnode->cast<CNodePtr>());
84 }
85 }
86 return caller_cnodes;
87 }
88
SearchFuncGraphCallers(const FuncGraphPtr & func_graph,bool eliminate_only_returned_parameter)89 static inline std::pair<FuncGraphPtr, std::vector<CNodePtr>> SearchFuncGraphCallers(
90 const FuncGraphPtr &func_graph, bool eliminate_only_returned_parameter) {
91 for (const auto &fg : func_graph->func_graphs_used_total()) {
92 if (fg->has_flag(FUNC_GRAPH_FLAG_DEFER_INLINE) || fg->has_flag(FUNC_GRAPH_RECOMPUTE_K_GRAPH)) {
93 continue;
94 }
95 const auto ¶meters = fg->parameters();
96 MS_EXCEPTION_IF_NULL(fg->manager());
97 const auto &manager_node_users = fg->manager()->node_users();
98 // Check if no user parameter or only one user in output tuple.
99 bool exist_param_unused =
100 std::any_of(parameters.begin(), parameters.end(),
101 [&manager_node_users, &fg, eliminate_only_returned_parameter](const AnfNodePtr ¶meter) {
102 const auto &node_users_it = manager_node_users.find(parameter);
103 // No user parameter.
104 if (node_users_it == manager_node_users.end() || node_users_it->second.empty()) {
105 return true;
106 }
107 // We will check the tuple output, if only one user.
108 if (eliminate_only_returned_parameter && fg->has_flag(FUNC_GRAPH_FLAG_NO_INLINE) &&
109 node_users_it->second.size() == 1) {
110 auto user = node_users_it->second.begin()->first;
111 // The parameter only used as returned MakeTuple's element.
112 if (IsPrimitiveCNode(user, prim::kPrimMakeTuple) && fg->output() == user) {
113 return true;
114 }
115 }
116 return false;
117 });
118 if (exist_param_unused) {
119 const auto &callers = GetCallers(fg);
120 if (!callers.empty()) {
121 return {fg, callers};
122 }
123 }
124 }
125 return {nullptr, {}};
126 }
127
EraseUnusedParameters(const FuncGraphPtr & fg,bool eliminate_only_returned_parameter)128 static inline std::pair<mindspore::HashSet<size_t>, mindspore::HashMap<size_t, size_t>> EraseUnusedParameters(
129 const FuncGraphPtr &fg, bool eliminate_only_returned_parameter) {
130 MS_EXCEPTION_IF_NULL(fg);
131 const FuncGraphManagerPtr &manager = fg->manager();
132 MS_EXCEPTION_IF_NULL(manager);
133 const auto &manager_node_users = manager->node_users();
134 const auto ¶meters = fg->parameters();
135 mindspore::HashSet<size_t> unused_parameter_indexes;
136 mindspore::HashMap<size_t, size_t> only_return_parameter_indexes;
137 // Traverse to find all unused parameters.
138 size_t index = 0;
139 for (const auto ¶meter : parameters) {
140 const auto &node_users_it = manager_node_users.find(parameter);
141 if (node_users_it == manager_node_users.end() || node_users_it->second.empty()) {
142 (void)unused_parameter_indexes.emplace(index);
143 } else if (eliminate_only_returned_parameter && fg->has_flag(FUNC_GRAPH_FLAG_NO_INLINE) &&
144 node_users_it->second.size() == 1) {
145 auto user = node_users_it->second.begin()->first;
146 auto pos = node_users_it->second.begin()->second;
147 // The parameter only used as returned MakeTuple's element.
148 if (IsPrimitiveCNode(user, prim::kPrimMakeTuple) && fg->output() == user) {
149 MS_LOG(DEBUG) << "Found only returned parameter[" << index << "] at output index[" << pos << "] of "
150 << user->DebugString();
151 (void)only_return_parameter_indexes.emplace(pos, index);
152 (void)unused_parameter_indexes.emplace(index);
153 // Erase the unused element in returned MakeTuple CNode.
154 auto user_cnode = dyn_cast<CNode>(user);
155 MS_EXCEPTION_IF_NULL(user_cnode);
156 auto zero_value = NewValueNode(MakeValue<int64_t>(0));
157 zero_value->set_abstract(std::make_shared<abstract::AbstractScalar>(std::make_shared<Int64Imm>(0)));
158 user_cnode->set_input(IntToSize(pos), zero_value);
159 }
160 }
161 index++;
162 }
163 // Erase unused parameters.
164 std::vector<AnfNodePtr> new_parameters;
165 const auto &var_arg_node = fg->GetVariableArgParameter();
166 const auto &kw_arg_node = fg->GetVariableKwargParameter();
167 const auto &kw_only_args = fg->GetKwOnlyArgsParameters();
168 const size_t fv_position = parameters.size() - fg->fv_param_count();
169 for (size_t i = 0; i < parameters.size(); i++) {
170 const auto ¶m_i = parameters[i];
171 if (unused_parameter_indexes.find(i) == unused_parameter_indexes.end()) {
172 (void)new_parameters.emplace_back(param_i);
173 } else {
174 // VarArgs, KwArgs, KwOnlyArgs may not following the index as the Positional Arguments.
175 if (param_i == var_arg_node) {
176 fg->set_has_vararg(false);
177 (void)unused_parameter_indexes.erase(i);
178 } else if (param_i == kw_arg_node) {
179 fg->set_has_kwarg(false);
180 (void)unused_parameter_indexes.erase(i);
181 } else {
182 bool is_kw_only_arg = std::any_of(kw_only_args.cbegin(), kw_only_args.cend(),
183 [param_i](const auto &kw_only_arg) { return kw_only_arg == param_i; });
184 if (is_kw_only_arg) {
185 if (fg->kwonlyargs_count() <= 0) {
186 MS_LOG(INTERNAL_EXCEPTION) << "The kw_only_args_count is 0 when a kw_only_arg should be removed";
187 }
188 fg->set_kwonlyargs_count(fg->kwonlyargs_count() - 1);
189 (void)unused_parameter_indexes.erase(i);
190 }
191 }
192 if (i >= fv_position) {
193 fg->set_fv_param_count(fg->fv_param_count() - 1);
194 }
195 MS_LOG(DEBUG) << "Erase parameter: " << param_i->DebugString() << ", index: " << i;
196 }
197 }
198 manager->SetParameters(fg, new_parameters);
199 return {unused_parameter_indexes, only_return_parameter_indexes};
200 }
201
202 // Adjust the call arguments of func graph whose parameter's eliminated.
AdjustCallerArgs(const FuncGraphPtr & called,const CNodePtr & caller,const mindspore::HashSet<size_t> & unused_parameter_indexes)203 static inline void AdjustCallerArgs(const FuncGraphPtr &called, const CNodePtr &caller,
204 const mindspore::HashSet<size_t> &unused_parameter_indexes) {
205 size_t arg_start_index = 1;
206 MS_EXCEPTION_IF_NULL(caller->func_graph());
207 const FuncGraphManagerPtr &manager = caller->func_graph()->manager();
208 MS_EXCEPTION_IF_NULL(manager);
209 std::vector<AnfNodePtr> new_args = {caller->input(0)};
210 if (IsPrimitiveCNode(caller, prim::kPrimPartial)) {
211 (void)new_args.emplace_back(caller->input(1));
212 arg_start_index = arg_start_index + 1;
213 }
214 for (size_t i = 0; i < caller->size() - arg_start_index; i++) {
215 if (unused_parameter_indexes.find(i) == unused_parameter_indexes.end()) {
216 (void)new_args.emplace_back(caller->input(i + arg_start_index));
217 } else {
218 MS_LOG(DEBUG) << "Erase arg: " << caller->input(i + arg_start_index)->DebugString();
219 }
220 }
221 // Remove any Args which may be packed into VarArgs if VarArgs is not used in called FuncGraph;
222 // Note: 1. If there is any *args or key=value argument in call site, it will be converted to unpack_call
223 // CNode. So in this direct call case, all arguments should be plain arguments.
224 // 2. The arguments in caller may be less than the formal parameters in called as some parameters can have
225 // default value.
226 if (!called->has_vararg() &&
227 caller->size() > (1 + IntToSize(called->GetPositionalArgsCount()) + called->fv_param_count())) {
228 size_t start_offset = IntToSize(called->GetPositionalArgsCount()) + arg_start_index;
229 size_t end_offset = called->fv_param_count();
230 if (start_offset > new_args.size()) {
231 MS_LOG(INTERNAL_EXCEPTION) << "The start_offset is " << start_offset << ", which exceeds the number of new args "
232 << new_args.size() << ".";
233 }
234 (void)new_args.erase(new_args.cbegin() + SizeToLong(start_offset), new_args.cend() - SizeToLong(end_offset));
235 }
236
237 TraceGuard trace_guard(std::make_shared<TraceCopy>(caller->debug_info()));
238 auto new_caller = caller->func_graph()->NewCNode(new_args);
239 new_caller->set_abstract(caller->abstract());
240 // Should be done before manager. Replace as caller CNode will be dropped after Replace, the ReplaceInOrder will be
241 // no effect.
242 caller->func_graph()->ReplaceInOrder(caller, new_caller);
243 (void)manager->Replace(caller, new_caller);
244 }
245
246 // Adjust the caller(returned tuple)'s caller(getitem call)'s caller of func graph.
247 // Since the elements in returned tuple maybe eliminated,
248 // we should convert getitem(returned_tuple, x) into the eliminating argument itself.
AdjustGetItemCall(const CNodePtr & caller,const mindspore::HashMap<size_t,size_t> & only_return_parameter_indexes)249 static inline void AdjustGetItemCall(const CNodePtr &caller,
250 const mindspore::HashMap<size_t, size_t> &only_return_parameter_indexes) {
251 MS_EXCEPTION_IF_NULL(caller->func_graph());
252 const FuncGraphManagerPtr &manager = caller->func_graph()->manager();
253 MS_EXCEPTION_IF_NULL(manager);
254 if (only_return_parameter_indexes.empty()) {
255 return;
256 }
257 const auto &node_users = manager->node_users();
258 const auto &iter = node_users.find(caller);
259 if (iter == node_users.end() || iter->second.empty()) {
260 return;
261 }
262 std::vector<std::pair<AnfNodePtr, AnfNodePtr>> replacing_nodes;
263 auto &all_users = iter->second;
264 for (auto &user : all_users) {
265 auto node = user.first;
266 MS_EXCEPTION_IF_NULL(node);
267 if (!IsPrimitiveCNode(node, prim::kPrimTupleGetItem)) {
268 MS_LOG(ERROR) << "We expect a GetItem from the return tuple, but got " << node->DebugString();
269 continue;
270 }
271 auto getitem_cnode = dyn_cast<CNode>(node);
272 MS_EXCEPTION_IF_NULL(getitem_cnode);
273 // Check if it's the eliminated element of returned tuple.
274 constexpr size_t getitem_index_pos = 2;
275 auto &index_node = getitem_cnode->input(getitem_index_pos);
276 auto index_value = GetValueNode<Int64ImmPtr>(index_node);
277 if (index_value == nullptr || index_value->value() < 0) {
278 MS_LOG(INTERNAL_EXCEPTION) << "The index_value is incorrect, " << index_node->DebugString();
279 }
280 size_t index_value_imm = LongToSize(index_value->value());
281 const auto &index_pos = only_return_parameter_indexes.find(index_value_imm + 1);
282 if (index_pos == only_return_parameter_indexes.end()) {
283 continue;
284 }
285
286 // Found the tuple element, to replace it.
287 auto eliminating_argument_pos = index_pos->second;
288 MS_LOG(DEBUG) << "Found unused getitem CNode: " << getitem_cnode->DebugString() << ", index: " << index_value_imm
289 << ", eliminating_argument_pos: " << eliminating_argument_pos;
290 // Replace the getitem CNode with the eliminated argument.
291 auto &arg = caller->input(eliminating_argument_pos + 1);
292 (void)replacing_nodes.emplace_back(std::pair(getitem_cnode, arg));
293 }
294 for (auto &nodes : replacing_nodes) {
295 MS_LOG(DEBUG) << "Replace: " << nodes.first->DebugString() << ", with: " << nodes.second->DebugString();
296 (void)manager->Replace(nodes.first, nodes.second);
297 }
298 }
299
300 class ParameterEliminator {
301 public:
302 ParameterEliminator() = default;
303 virtual ~ParameterEliminator() = default;
operator()304 bool operator()(const FuncGraphPtr &func_graph, const OptimizerPtr &) {
305 bool changes = false;
306 while (true) {
307 const auto &[fg, callers] = SearchFuncGraphCallers(func_graph, eliminate_only_returned_parameter_);
308 if (fg == nullptr) {
309 break;
310 }
311 const auto &[unused_parameter_indexes, only_return_parameter_indexes] =
312 EraseUnusedParameters(fg, eliminate_only_returned_parameter_);
313 for (auto caller : callers) {
314 MS_LOG(DEBUG) << "caller: " << caller->DebugString();
315 // Replace the getitem CNodes with the arguments.
316 if (eliminate_only_returned_parameter_) {
317 AdjustGetItemCall(caller, only_return_parameter_indexes);
318 }
319 // Erase the arguments for eliminated parameters.
320 AdjustCallerArgs(fg, caller, unused_parameter_indexes);
321 }
322 changes = true;
323 }
324 return changes;
325 }
326
set_eliminate_only_returned_parameter(bool eliminate_only_returned_parameter)327 void set_eliminate_only_returned_parameter(bool eliminate_only_returned_parameter) {
328 eliminate_only_returned_parameter_ = eliminate_only_returned_parameter;
329 }
330
331 private:
332 bool eliminate_only_returned_parameter_{false};
333 };
334
335 class PartialUnusedArgsEliminate {
336 public:
337 PartialUnusedArgsEliminate() = default;
338 virtual ~PartialUnusedArgsEliminate() = default;
operator()339 bool operator()(const FuncGraphPtr &func_graph) {
340 MS_EXCEPTION_IF_NULL(func_graph);
341 auto manager = func_graph->manager();
342 MS_EXCEPTION_IF_NULL(manager);
343 bool changed = false;
344 auto fgs = func_graph->func_graphs_used_total();
345 for (const auto &fg : fgs) {
346 MS_EXCEPTION_IF_NULL(fg);
347 std::vector<CNodePtr> partial_nodes;
348 if (!GetUserPartialNodes(fg, &partial_nodes)) {
349 continue;
350 }
351 std::vector<size_t> unused_parameter_idx;
352 std::vector<AnfNodePtr> new_parameters;
353 const auto &node_users = manager->node_users();
354 const auto &origin_parameters = fg->parameters();
355 bool added_forward_u = fg->has_flag(kFuncGraphFlagAddedForwardU);
356 AnfNodePtr unused_arg_u = nullptr;
357 for (size_t i = 0; i < origin_parameters.size(); ++i) {
358 auto origin_para = origin_parameters[i];
359 auto iter = node_users.find(origin_para);
360 // Currently, we don't eliminate the function parameter node because it will produce DeadNode after renormalize.
361 if (!HasAbstractFunction(origin_para) && (iter == node_users.end() || iter->second.empty())) {
362 (void)unused_parameter_idx.emplace_back(i);
363 } else if (added_forward_u && HasAbstractUMonad(origin_para) && i < origin_parameters.size() - 1) {
364 // The fv u monad from fprop should be replaced with the forward u added by pass 'add_forward_monad_depend.h'.
365 (void)unused_parameter_idx.emplace_back(i);
366 unused_arg_u = origin_para;
367 } else {
368 (void)new_parameters.emplace_back(origin_para);
369 }
370 }
371 if (unused_parameter_idx.empty()) {
372 continue;
373 }
374 mindspore::HashMap<AnfNodePtr, AnfNodePtr> repl;
375 if (!GetPartialRepl(partial_nodes, unused_parameter_idx, &repl)) {
376 continue;
377 }
378 if (unused_arg_u != nullptr) {
379 (void)manager->Replace(unused_arg_u, origin_parameters[origin_parameters.size() - 1]);
380 }
381 fg->set_parameters(new_parameters);
382 auto tr = manager->Transact();
383 for (auto &item : repl) {
384 (void)tr.Replace(item.first, item.second);
385 }
386 tr.Commit();
387 changed = true;
388 }
389 return changed;
390 }
391
392 private:
HasAbstractFunction(const AnfNodePtr & node)393 static bool HasAbstractFunction(const AnfNodePtr &node) {
394 return node->abstract() != nullptr && node->abstract()->isa<abstract::AbstractFunction>();
395 }
396
GetUserPartialNodes(const FuncGraphPtr & fg,std::vector<CNodePtr> * partial_nodes)397 static bool GetUserPartialNodes(const FuncGraphPtr &fg, std::vector<CNodePtr> *partial_nodes) {
398 for (const auto &node_and_idx : fg->func_graph_cnodes_index()) {
399 auto user_node = node_and_idx.first->first;
400 if (!IsPrimitiveCNode(user_node, prim::kPrimPartial)) {
401 return false;
402 }
403 (void)partial_nodes->emplace_back(user_node->cast<CNodePtr>());
404 }
405 return true;
406 }
407
GetPartialRepl(const std::vector<CNodePtr> & partial_nodes,const std::vector<size_t> & unused_parameter_idx,mindspore::HashMap<AnfNodePtr,AnfNodePtr> * repl)408 static bool GetPartialRepl(const std::vector<CNodePtr> &partial_nodes,
409 const std::vector<size_t> &unused_parameter_idx,
410 mindspore::HashMap<AnfNodePtr, AnfNodePtr> *repl) {
411 constexpr auto kPartialFirstArgIndex = 2;
412 for (const auto &partial : partial_nodes) {
413 const auto &origin_partial_inputs = partial->inputs();
414 std::vector<AnfNodePtr> new_partial_inputs;
415 size_t j = 0;
416 for (size_t i = 0; i < origin_partial_inputs.size(); ++i) {
417 if (j < unused_parameter_idx.size() && i >= kPartialFirstArgIndex &&
418 i - kPartialFirstArgIndex == unused_parameter_idx[j]) {
419 ++j;
420 continue;
421 } else {
422 (void)new_partial_inputs.emplace_back(origin_partial_inputs[i]);
423 }
424 }
425 // The unused parameter should be one of the partial inputs.
426 if (j < unused_parameter_idx.size()) {
427 return false;
428 }
429 auto partial_fg = partial->func_graph();
430 MS_EXCEPTION_IF_NULL(partial_fg);
431 auto new_partial = partial_fg->NewCNode(new_partial_inputs);
432 (void)repl->emplace(partial, new_partial);
433 }
434 return true;
435 }
436 };
437 } // namespace irpass
438 } // namespace opt
439 } // namespace mindspore
440 #endif // MINDSPORE_CCSRC_FRONTEND_OPTIMIZER_IRPASS_PARAMETER_ELIMINATE_H
441