• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2021-2023 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #ifndef MINDSPORE_CCSRC_FRONTEND_OPTIMIZER_IRPASS_PARAMETER_ELIMINATE_H
18 #define MINDSPORE_CCSRC_FRONTEND_OPTIMIZER_IRPASS_PARAMETER_ELIMINATE_H
19 #include <vector>
20 #include <utility>
21 #include <memory>
22 
23 #include "utils/hash_set.h"
24 #include "mindspore/core/ops/sequence_ops.h"
25 #include "mindspore/core/ops/framework_ops.h"
26 #include "frontend/optimizer/irpass.h"
27 #include "frontend/optimizer/optimizer.h"
28 #include "frontend/optimizer/anf_visitor.h"
29 #include "ir/manager.h"
30 #include "ir/func_graph.h"
31 #include "frontend/operator/ops.h"
32 
33 namespace mindspore {
34 namespace opt {
35 namespace irpass {
CheckSwitchCallValid(const CNodePtr & switch_call)36 static inline void CheckSwitchCallValid(const CNodePtr &switch_call) {
37   if (switch_call->size() > 1) {
38     // Means call switch(arg1, ...) has args.
39     constexpr auto recursive_count = 2;
40     MS_LOG(INTERNAL_EXCEPTION) << "After switch_call_monad_eliminater pass, the call switch node should not has args."
41                                << " The call_switch_cnode is: " << switch_call->DebugString(recursive_count);
42   }
43 }
44 
GetCallers(const FuncGraphPtr & fg)45 static inline std::vector<CNodePtr> GetCallers(const FuncGraphPtr &fg) {
46   MS_EXCEPTION_IF_NULL(fg);
47   const auto &fg_caller_and_indexes = fg->func_graph_cnodes_index();
48   std::vector<CNodePtr> caller_cnodes = {};
49   // Find all caller of fg.
50   auto manager = fg->manager();
51   MS_EXCEPTION_IF_NULL(manager);
52   auto &node_users = manager->node_users();
53   for (const auto &it : fg_caller_and_indexes) {
54     const auto &fg_caller_and_index = it.first;
55     auto caller_cnode = fg_caller_and_index->first;
56     auto index = fg_caller_and_index->second;
57     // If index != 0, the caller is a indirect caller, can't erase the parameter of graph.
58     // Because in this situation ValueNode<FuncGraph> is a input of Return or of MakeTuple.
59     MS_LOG(DEBUG) << "index: " << index;
60     // Process has partial func_graph with Primitive
61     // %1 = Partial(func_graph, arg1, arg2, ...)
62     if (index == 1 && IsPrimitiveCNode(caller_cnode, prim::kPrimPartial)) {
63       auto iter = node_users.find(caller_cnode);
64       for (auto &user : iter->second) {
65         auto &user_node = user.first;
66         auto user_cnode = user_node->cast<CNodePtr>();
67         // Check user of partial (switch), the numbers of args should be 0.
68         if (IsPrimitiveCNode(user_cnode, prim::kPrimSwitch)) {
69           // Call switch()
70           auto call_switchs = node_users[user_cnode];
71           for (auto call_switch_iter : call_switchs) {
72             CheckSwitchCallValid(call_switch_iter.first->cast<CNodePtr>());
73           }
74           if (std::find(caller_cnodes.begin(), caller_cnodes.end(), caller_cnode) == caller_cnodes.end()) {
75             (void)caller_cnodes.emplace_back(caller_cnode->cast<CNodePtr>());
76           }
77         }
78       }
79     } else if (index != 0) {
80       return {};
81     } else {
82       // Process call func_graph: %1 = func_graph(arg1, arg2, ...)
83       (void)caller_cnodes.emplace_back(caller_cnode->cast<CNodePtr>());
84     }
85   }
86   return caller_cnodes;
87 }
88 
SearchFuncGraphCallers(const FuncGraphPtr & func_graph,bool eliminate_only_returned_parameter)89 static inline std::pair<FuncGraphPtr, std::vector<CNodePtr>> SearchFuncGraphCallers(
90   const FuncGraphPtr &func_graph, bool eliminate_only_returned_parameter) {
91   for (const auto &fg : func_graph->func_graphs_used_total()) {
92     if (fg->has_flag(FUNC_GRAPH_FLAG_DEFER_INLINE) || fg->has_flag(FUNC_GRAPH_RECOMPUTE_K_GRAPH)) {
93       continue;
94     }
95     const auto &parameters = fg->parameters();
96     MS_EXCEPTION_IF_NULL(fg->manager());
97     const auto &manager_node_users = fg->manager()->node_users();
98     // Check if no user parameter or only one user in output tuple.
99     bool exist_param_unused =
100       std::any_of(parameters.begin(), parameters.end(),
101                   [&manager_node_users, &fg, eliminate_only_returned_parameter](const AnfNodePtr &parameter) {
102                     const auto &node_users_it = manager_node_users.find(parameter);
103                     // No user parameter.
104                     if (node_users_it == manager_node_users.end() || node_users_it->second.empty()) {
105                       return true;
106                     }
107                     // We will check the tuple output, if only one user.
108                     if (eliminate_only_returned_parameter && fg->has_flag(FUNC_GRAPH_FLAG_NO_INLINE) &&
109                         node_users_it->second.size() == 1) {
110                       auto user = node_users_it->second.begin()->first;
111                       // The parameter only used as returned MakeTuple's element.
112                       if (IsPrimitiveCNode(user, prim::kPrimMakeTuple) && fg->output() == user) {
113                         return true;
114                       }
115                     }
116                     return false;
117                   });
118     if (exist_param_unused) {
119       const auto &callers = GetCallers(fg);
120       if (!callers.empty()) {
121         return {fg, callers};
122       }
123     }
124   }
125   return {nullptr, {}};
126 }
127 
EraseUnusedParameters(const FuncGraphPtr & fg,bool eliminate_only_returned_parameter)128 static inline std::pair<mindspore::HashSet<size_t>, mindspore::HashMap<size_t, size_t>> EraseUnusedParameters(
129   const FuncGraphPtr &fg, bool eliminate_only_returned_parameter) {
130   MS_EXCEPTION_IF_NULL(fg);
131   const FuncGraphManagerPtr &manager = fg->manager();
132   MS_EXCEPTION_IF_NULL(manager);
133   const auto &manager_node_users = manager->node_users();
134   const auto &parameters = fg->parameters();
135   mindspore::HashSet<size_t> unused_parameter_indexes;
136   mindspore::HashMap<size_t, size_t> only_return_parameter_indexes;
137   // Traverse to find all unused parameters.
138   size_t index = 0;
139   for (const auto &parameter : parameters) {
140     const auto &node_users_it = manager_node_users.find(parameter);
141     if (node_users_it == manager_node_users.end() || node_users_it->second.empty()) {
142       (void)unused_parameter_indexes.emplace(index);
143     } else if (eliminate_only_returned_parameter && fg->has_flag(FUNC_GRAPH_FLAG_NO_INLINE) &&
144                node_users_it->second.size() == 1) {
145       auto user = node_users_it->second.begin()->first;
146       auto pos = node_users_it->second.begin()->second;
147       // The parameter only used as returned MakeTuple's element.
148       if (IsPrimitiveCNode(user, prim::kPrimMakeTuple) && fg->output() == user) {
149         MS_LOG(DEBUG) << "Found only returned parameter[" << index << "] at output index[" << pos << "] of "
150                       << user->DebugString();
151         (void)only_return_parameter_indexes.emplace(pos, index);
152         (void)unused_parameter_indexes.emplace(index);
153         // Erase the unused element in returned MakeTuple CNode.
154         auto user_cnode = dyn_cast<CNode>(user);
155         MS_EXCEPTION_IF_NULL(user_cnode);
156         auto zero_value = NewValueNode(MakeValue<int64_t>(0));
157         zero_value->set_abstract(std::make_shared<abstract::AbstractScalar>(std::make_shared<Int64Imm>(0)));
158         user_cnode->set_input(IntToSize(pos), zero_value);
159       }
160     }
161     index++;
162   }
163   // Erase unused parameters.
164   std::vector<AnfNodePtr> new_parameters;
165   const auto &var_arg_node = fg->GetVariableArgParameter();
166   const auto &kw_arg_node = fg->GetVariableKwargParameter();
167   const auto &kw_only_args = fg->GetKwOnlyArgsParameters();
168   const size_t fv_position = parameters.size() - fg->fv_param_count();
169   for (size_t i = 0; i < parameters.size(); i++) {
170     const auto &param_i = parameters[i];
171     if (unused_parameter_indexes.find(i) == unused_parameter_indexes.end()) {
172       (void)new_parameters.emplace_back(param_i);
173     } else {
174       // VarArgs, KwArgs, KwOnlyArgs may not following the index as the Positional Arguments.
175       if (param_i == var_arg_node) {
176         fg->set_has_vararg(false);
177         (void)unused_parameter_indexes.erase(i);
178       } else if (param_i == kw_arg_node) {
179         fg->set_has_kwarg(false);
180         (void)unused_parameter_indexes.erase(i);
181       } else {
182         bool is_kw_only_arg = std::any_of(kw_only_args.cbegin(), kw_only_args.cend(),
183                                           [param_i](const auto &kw_only_arg) { return kw_only_arg == param_i; });
184         if (is_kw_only_arg) {
185           if (fg->kwonlyargs_count() <= 0) {
186             MS_LOG(INTERNAL_EXCEPTION) << "The kw_only_args_count is 0 when a kw_only_arg should be removed";
187           }
188           fg->set_kwonlyargs_count(fg->kwonlyargs_count() - 1);
189           (void)unused_parameter_indexes.erase(i);
190         }
191       }
192       if (i >= fv_position) {
193         fg->set_fv_param_count(fg->fv_param_count() - 1);
194       }
195       MS_LOG(DEBUG) << "Erase parameter: " << param_i->DebugString() << ", index: " << i;
196     }
197   }
198   manager->SetParameters(fg, new_parameters);
199   return {unused_parameter_indexes, only_return_parameter_indexes};
200 }
201 
202 // Adjust the call arguments of func graph whose parameter's eliminated.
AdjustCallerArgs(const FuncGraphPtr & called,const CNodePtr & caller,const mindspore::HashSet<size_t> & unused_parameter_indexes)203 static inline void AdjustCallerArgs(const FuncGraphPtr &called, const CNodePtr &caller,
204                                     const mindspore::HashSet<size_t> &unused_parameter_indexes) {
205   size_t arg_start_index = 1;
206   MS_EXCEPTION_IF_NULL(caller->func_graph());
207   const FuncGraphManagerPtr &manager = caller->func_graph()->manager();
208   MS_EXCEPTION_IF_NULL(manager);
209   std::vector<AnfNodePtr> new_args = {caller->input(0)};
210   if (IsPrimitiveCNode(caller, prim::kPrimPartial)) {
211     (void)new_args.emplace_back(caller->input(1));
212     arg_start_index = arg_start_index + 1;
213   }
214   for (size_t i = 0; i < caller->size() - arg_start_index; i++) {
215     if (unused_parameter_indexes.find(i) == unused_parameter_indexes.end()) {
216       (void)new_args.emplace_back(caller->input(i + arg_start_index));
217     } else {
218       MS_LOG(DEBUG) << "Erase arg: " << caller->input(i + arg_start_index)->DebugString();
219     }
220   }
221   // Remove any Args which may be packed into VarArgs if VarArgs is not used in called FuncGraph;
222   // Note: 1. If there is any *args or key=value argument in call site, it will be converted to unpack_call
223   // CNode. So in this direct call case, all arguments should be plain arguments.
224   //       2. The arguments in caller may be less than the formal parameters in called as some parameters can have
225   //       default value.
226   if (!called->has_vararg() &&
227       caller->size() > (1 + IntToSize(called->GetPositionalArgsCount()) + called->fv_param_count())) {
228     size_t start_offset = IntToSize(called->GetPositionalArgsCount()) + arg_start_index;
229     size_t end_offset = called->fv_param_count();
230     if (start_offset > new_args.size()) {
231       MS_LOG(INTERNAL_EXCEPTION) << "The start_offset is " << start_offset << ", which exceeds the number of new args "
232                                  << new_args.size() << ".";
233     }
234     (void)new_args.erase(new_args.cbegin() + SizeToLong(start_offset), new_args.cend() - SizeToLong(end_offset));
235   }
236 
237   TraceGuard trace_guard(std::make_shared<TraceCopy>(caller->debug_info()));
238   auto new_caller = caller->func_graph()->NewCNode(new_args);
239   new_caller->set_abstract(caller->abstract());
240   // Should be done before manager. Replace as caller CNode will be dropped after Replace, the ReplaceInOrder will be
241   // no effect.
242   caller->func_graph()->ReplaceInOrder(caller, new_caller);
243   (void)manager->Replace(caller, new_caller);
244 }
245 
246 // Adjust the caller(returned tuple)'s caller(getitem call)'s caller of func graph.
247 // Since the elements in returned tuple maybe eliminated,
248 // we should convert getitem(returned_tuple, x) into the eliminating argument itself.
AdjustGetItemCall(const CNodePtr & caller,const mindspore::HashMap<size_t,size_t> & only_return_parameter_indexes)249 static inline void AdjustGetItemCall(const CNodePtr &caller,
250                                      const mindspore::HashMap<size_t, size_t> &only_return_parameter_indexes) {
251   MS_EXCEPTION_IF_NULL(caller->func_graph());
252   const FuncGraphManagerPtr &manager = caller->func_graph()->manager();
253   MS_EXCEPTION_IF_NULL(manager);
254   if (only_return_parameter_indexes.empty()) {
255     return;
256   }
257   const auto &node_users = manager->node_users();
258   const auto &iter = node_users.find(caller);
259   if (iter == node_users.end() || iter->second.empty()) {
260     return;
261   }
262   std::vector<std::pair<AnfNodePtr, AnfNodePtr>> replacing_nodes;
263   auto &all_users = iter->second;
264   for (auto &user : all_users) {
265     auto node = user.first;
266     MS_EXCEPTION_IF_NULL(node);
267     if (!IsPrimitiveCNode(node, prim::kPrimTupleGetItem)) {
268       MS_LOG(ERROR) << "We expect a GetItem from the return tuple, but got " << node->DebugString();
269       continue;
270     }
271     auto getitem_cnode = dyn_cast<CNode>(node);
272     MS_EXCEPTION_IF_NULL(getitem_cnode);
273     // Check if it's the eliminated element of returned tuple.
274     constexpr size_t getitem_index_pos = 2;
275     auto &index_node = getitem_cnode->input(getitem_index_pos);
276     auto index_value = GetValueNode<Int64ImmPtr>(index_node);
277     if (index_value == nullptr || index_value->value() < 0) {
278       MS_LOG(INTERNAL_EXCEPTION) << "The index_value is incorrect, " << index_node->DebugString();
279     }
280     size_t index_value_imm = LongToSize(index_value->value());
281     const auto &index_pos = only_return_parameter_indexes.find(index_value_imm + 1);
282     if (index_pos == only_return_parameter_indexes.end()) {
283       continue;
284     }
285 
286     // Found the tuple element, to replace it.
287     auto eliminating_argument_pos = index_pos->second;
288     MS_LOG(DEBUG) << "Found unused getitem CNode: " << getitem_cnode->DebugString() << ", index: " << index_value_imm
289                   << ", eliminating_argument_pos: " << eliminating_argument_pos;
290     // Replace the getitem CNode with the eliminated argument.
291     auto &arg = caller->input(eliminating_argument_pos + 1);
292     (void)replacing_nodes.emplace_back(std::pair(getitem_cnode, arg));
293   }
294   for (auto &nodes : replacing_nodes) {
295     MS_LOG(DEBUG) << "Replace: " << nodes.first->DebugString() << ", with: " << nodes.second->DebugString();
296     (void)manager->Replace(nodes.first, nodes.second);
297   }
298 }
299 
300 class ParameterEliminator {
301  public:
302   ParameterEliminator() = default;
303   virtual ~ParameterEliminator() = default;
operator()304   bool operator()(const FuncGraphPtr &func_graph, const OptimizerPtr &) {
305     bool changes = false;
306     while (true) {
307       const auto &[fg, callers] = SearchFuncGraphCallers(func_graph, eliminate_only_returned_parameter_);
308       if (fg == nullptr) {
309         break;
310       }
311       const auto &[unused_parameter_indexes, only_return_parameter_indexes] =
312         EraseUnusedParameters(fg, eliminate_only_returned_parameter_);
313       for (auto caller : callers) {
314         MS_LOG(DEBUG) << "caller: " << caller->DebugString();
315         // Replace the getitem CNodes with the arguments.
316         if (eliminate_only_returned_parameter_) {
317           AdjustGetItemCall(caller, only_return_parameter_indexes);
318         }
319         // Erase the arguments for eliminated parameters.
320         AdjustCallerArgs(fg, caller, unused_parameter_indexes);
321       }
322       changes = true;
323     }
324     return changes;
325   }
326 
set_eliminate_only_returned_parameter(bool eliminate_only_returned_parameter)327   void set_eliminate_only_returned_parameter(bool eliminate_only_returned_parameter) {
328     eliminate_only_returned_parameter_ = eliminate_only_returned_parameter;
329   }
330 
331  private:
332   bool eliminate_only_returned_parameter_{false};
333 };
334 
335 class PartialUnusedArgsEliminate {
336  public:
337   PartialUnusedArgsEliminate() = default;
338   virtual ~PartialUnusedArgsEliminate() = default;
operator()339   bool operator()(const FuncGraphPtr &func_graph) {
340     MS_EXCEPTION_IF_NULL(func_graph);
341     auto manager = func_graph->manager();
342     MS_EXCEPTION_IF_NULL(manager);
343     bool changed = false;
344     auto fgs = func_graph->func_graphs_used_total();
345     for (const auto &fg : fgs) {
346       MS_EXCEPTION_IF_NULL(fg);
347       std::vector<CNodePtr> partial_nodes;
348       if (!GetUserPartialNodes(fg, &partial_nodes)) {
349         continue;
350       }
351       std::vector<size_t> unused_parameter_idx;
352       std::vector<AnfNodePtr> new_parameters;
353       const auto &node_users = manager->node_users();
354       const auto &origin_parameters = fg->parameters();
355       bool added_forward_u = fg->has_flag(kFuncGraphFlagAddedForwardU);
356       AnfNodePtr unused_arg_u = nullptr;
357       for (size_t i = 0; i < origin_parameters.size(); ++i) {
358         auto origin_para = origin_parameters[i];
359         auto iter = node_users.find(origin_para);
360         // Currently, we don't eliminate the function parameter node because it will produce DeadNode after renormalize.
361         if (!HasAbstractFunction(origin_para) && (iter == node_users.end() || iter->second.empty())) {
362           (void)unused_parameter_idx.emplace_back(i);
363         } else if (added_forward_u && HasAbstractUMonad(origin_para) && i < origin_parameters.size() - 1) {
364           // The fv u monad from fprop should be replaced with the forward u added by pass 'add_forward_monad_depend.h'.
365           (void)unused_parameter_idx.emplace_back(i);
366           unused_arg_u = origin_para;
367         } else {
368           (void)new_parameters.emplace_back(origin_para);
369         }
370       }
371       if (unused_parameter_idx.empty()) {
372         continue;
373       }
374       mindspore::HashMap<AnfNodePtr, AnfNodePtr> repl;
375       if (!GetPartialRepl(partial_nodes, unused_parameter_idx, &repl)) {
376         continue;
377       }
378       if (unused_arg_u != nullptr) {
379         (void)manager->Replace(unused_arg_u, origin_parameters[origin_parameters.size() - 1]);
380       }
381       fg->set_parameters(new_parameters);
382       auto tr = manager->Transact();
383       for (auto &item : repl) {
384         (void)tr.Replace(item.first, item.second);
385       }
386       tr.Commit();
387       changed = true;
388     }
389     return changed;
390   }
391 
392  private:
HasAbstractFunction(const AnfNodePtr & node)393   static bool HasAbstractFunction(const AnfNodePtr &node) {
394     return node->abstract() != nullptr && node->abstract()->isa<abstract::AbstractFunction>();
395   }
396 
GetUserPartialNodes(const FuncGraphPtr & fg,std::vector<CNodePtr> * partial_nodes)397   static bool GetUserPartialNodes(const FuncGraphPtr &fg, std::vector<CNodePtr> *partial_nodes) {
398     for (const auto &node_and_idx : fg->func_graph_cnodes_index()) {
399       auto user_node = node_and_idx.first->first;
400       if (!IsPrimitiveCNode(user_node, prim::kPrimPartial)) {
401         return false;
402       }
403       (void)partial_nodes->emplace_back(user_node->cast<CNodePtr>());
404     }
405     return true;
406   }
407 
GetPartialRepl(const std::vector<CNodePtr> & partial_nodes,const std::vector<size_t> & unused_parameter_idx,mindspore::HashMap<AnfNodePtr,AnfNodePtr> * repl)408   static bool GetPartialRepl(const std::vector<CNodePtr> &partial_nodes,
409                              const std::vector<size_t> &unused_parameter_idx,
410                              mindspore::HashMap<AnfNodePtr, AnfNodePtr> *repl) {
411     constexpr auto kPartialFirstArgIndex = 2;
412     for (const auto &partial : partial_nodes) {
413       const auto &origin_partial_inputs = partial->inputs();
414       std::vector<AnfNodePtr> new_partial_inputs;
415       size_t j = 0;
416       for (size_t i = 0; i < origin_partial_inputs.size(); ++i) {
417         if (j < unused_parameter_idx.size() && i >= kPartialFirstArgIndex &&
418             i - kPartialFirstArgIndex == unused_parameter_idx[j]) {
419           ++j;
420           continue;
421         } else {
422           (void)new_partial_inputs.emplace_back(origin_partial_inputs[i]);
423         }
424       }
425       // The unused parameter should be one of the partial inputs.
426       if (j < unused_parameter_idx.size()) {
427         return false;
428       }
429       auto partial_fg = partial->func_graph();
430       MS_EXCEPTION_IF_NULL(partial_fg);
431       auto new_partial = partial_fg->NewCNode(new_partial_inputs);
432       (void)repl->emplace(partial, new_partial);
433     }
434     return true;
435   }
436 };
437 }  // namespace irpass
438 }  // namespace opt
439 }  // namespace mindspore
440 #endif  // MINDSPORE_CCSRC_FRONTEND_OPTIMIZER_IRPASS_PARAMETER_ELIMINATE_H
441