• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2021 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #ifndef MINDSPORE_CCSRC_FRONTEND_OPTIMIZER_IRPASS_PARAMETER_ELIMINATE_H
18 #define MINDSPORE_CCSRC_FRONTEND_OPTIMIZER_IRPASS_PARAMETER_ELIMINATE_H
19 #include <vector>
20 #include <utility>
21 #include <unordered_set>
22 #include <memory>
23 
24 #include "frontend/optimizer/irpass.h"
25 #include "frontend/optimizer/optimizer.h"
26 #include "frontend/optimizer/anf_visitor.h"
27 #include "ir/manager.h"
28 #include "ir/func_graph.h"
29 #include "frontend/operator/ops.h"
30 
31 namespace mindspore {
32 namespace opt {
33 namespace irpass {
34 
35 class ParameterEliminator {
36  public:
37   ParameterEliminator() = default;
38   virtual ~ParameterEliminator() = default;
operator()39   bool operator()(const FuncGraphPtr &func_graph, const OptimizerPtr &optimizer) {
40     const auto &manager = func_graph->manager();
41     MS_EXCEPTION_IF_NULL(manager);
42     bool changes = false;
43     while (true) {
44       const auto &[fg, callers] = SearchFuncGraphCallers(func_graph);
45       if (fg == nullptr) {
46         break;
47       }
48       auto manager = fg->manager();
49       MS_EXCEPTION_IF_NULL(manager);
50       const auto &erase_indexes = EraseUnusedParameters(fg, manager);
51       for (auto caller : callers) {
52         // Erase the corresponding args.
53         EraseArgs(caller, erase_indexes, manager);
54       }
55       changes = true;
56     }
57     return changes;
58   }
59 
60  private:
GetCallers(const FuncGraphPtr & fg)61   static std::vector<CNodePtr> GetCallers(const FuncGraphPtr &fg) {
62     const auto &fg_caller_and_indexes = fg->func_graph_cnodes_index();
63     std::vector<CNodePtr> caller_cnodes = {};
64     // Find all caller of fg.
65     for (const auto &it : fg_caller_and_indexes) {
66       const auto &fg_caller_and_index = it.first;
67       auto caller_cnode = fg_caller_and_index->first;
68       auto index = fg_caller_and_index->second;
69       // If index != 0, the caller is a indirect caller, can't erase the parameter of graph.Because
70       // in this situation ValueNode<FuncGraph> is a input of Return or of MakeTuple.
71       if (index != 0) {
72         return {};
73       }
74       caller_cnodes.push_back(caller_cnode->cast<CNodePtr>());
75     }
76     return caller_cnodes;
77   }
78 
SearchFuncGraphCallers(const FuncGraphPtr & func_graph)79   static std::pair<FuncGraphPtr, std::vector<CNodePtr>> SearchFuncGraphCallers(const FuncGraphPtr &func_graph) {
80     for (const auto &fg : func_graph->func_graphs_used_total()) {
81       if (fg->has_flag(FUNC_GRAPH_FLAG_DEFER_INLINE)) {
82         continue;
83       }
84       const auto &parameters = fg->parameters();
85       MS_EXCEPTION_IF_NULL(fg->manager());
86       const auto &manager_node_users = fg->manager()->node_users();
87       bool exist_param_unused =
88         std::any_of(parameters.begin(), parameters.end(), [&manager_node_users](const AnfNodePtr &parameter) {
89           const auto &node_users_it = manager_node_users.find(parameter);
90           return node_users_it == manager_node_users.end() || node_users_it->second.empty();
91         });
92       if (exist_param_unused) {
93         const auto &callers = GetCallers(fg);
94         if (!callers.empty()) {
95           return {fg, callers};
96         }
97       }
98     }
99     return {nullptr, {}};
100   }
101 
EraseUnusedParameters(const FuncGraphPtr & fg,const FuncGraphManagerPtr & manager)102   static std::unordered_set<size_t> EraseUnusedParameters(const FuncGraphPtr &fg, const FuncGraphManagerPtr &manager) {
103     MS_EXCEPTION_IF_NULL(fg->manager());
104     const auto &manager_node_users = fg->manager()->node_users();
105     const auto &parameters = fg->parameters();
106     std::unordered_set<size_t> unused_parameter_indexes;
107     // Traverse to find all unused parameters.
108     size_t index = 0;
109     for (const auto &parameter : parameters) {
110       const auto &node_users_it = manager_node_users.find(parameter);
111       if (node_users_it == manager_node_users.end() || node_users_it->second.empty()) {
112         unused_parameter_indexes.insert(index);
113       }
114       index++;
115     }
116     // Erase unused parameters.
117     std::vector<AnfNodePtr> new_parameters;
118     for (size_t i = 0; i < parameters.size(); i++) {
119       if (unused_parameter_indexes.find(i) == unused_parameter_indexes.end()) {
120         new_parameters.push_back(parameters[i]);
121       } else {
122         MS_LOG(DEBUG) << "Erase parameter:" << parameters[i]->DebugString() << ",index:" << i;
123       }
124     }
125     manager->SetParameters(fg, new_parameters);
126     return unused_parameter_indexes;
127   }
128 
EraseArgs(const CNodePtr & caller,const std::unordered_set<size_t> & unused_parameter_indexes,const FuncGraphManagerPtr & manager)129   static void EraseArgs(const CNodePtr &caller, const std::unordered_set<size_t> &unused_parameter_indexes,
130                         const FuncGraphManagerPtr &manager) {
131     std::vector<AnfNodePtr> new_args = {caller->inputs()[0]};
132     for (size_t i = 0; i < caller->inputs().size() - 1; i++) {
133       if (unused_parameter_indexes.find(i) == unused_parameter_indexes.end()) {
134         new_args.push_back(caller->inputs()[i + 1]);
135       } else {
136         MS_LOG(DEBUG) << "Erase arg:" << caller->inputs()[i + 1]->DebugString() << ",index:" << i;
137       }
138     }
139     TraceGuard trace_guard(std::make_shared<TraceCopy>(caller->debug_info()));
140     auto new_caller = caller->func_graph()->NewCNode(new_args);
141     new_caller->set_abstract(caller->abstract());
142     manager->Replace(caller, new_caller);
143   }
144 };
145 }  // namespace irpass
146 }  // namespace opt
147 }  // namespace mindspore
148 #endif  // MINDSPORE_CCSRC_FRONTEND_OPTIMIZER_IRPASS_PARAMETER_ELIMINATE_H
149