1 /** 2 * Copyright 2021 Huawei Technologies Co., Ltd 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 17 #ifndef MINDSPORE_CCSRC_FRONTEND_OPTIMIZER_IRPASS_PARAMETER_ELIMINATE_H 18 #define MINDSPORE_CCSRC_FRONTEND_OPTIMIZER_IRPASS_PARAMETER_ELIMINATE_H 19 #include <vector> 20 #include <utility> 21 #include <unordered_set> 22 #include <memory> 23 24 #include "frontend/optimizer/irpass.h" 25 #include "frontend/optimizer/optimizer.h" 26 #include "frontend/optimizer/anf_visitor.h" 27 #include "ir/manager.h" 28 #include "ir/func_graph.h" 29 #include "frontend/operator/ops.h" 30 31 namespace mindspore { 32 namespace opt { 33 namespace irpass { 34 35 class ParameterEliminator { 36 public: 37 ParameterEliminator() = default; 38 virtual ~ParameterEliminator() = default; operator()39 bool operator()(const FuncGraphPtr &func_graph, const OptimizerPtr &optimizer) { 40 const auto &manager = func_graph->manager(); 41 MS_EXCEPTION_IF_NULL(manager); 42 bool changes = false; 43 while (true) { 44 const auto &[fg, callers] = SearchFuncGraphCallers(func_graph); 45 if (fg == nullptr) { 46 break; 47 } 48 auto manager = fg->manager(); 49 MS_EXCEPTION_IF_NULL(manager); 50 const auto &erase_indexes = EraseUnusedParameters(fg, manager); 51 for (auto caller : callers) { 52 // Erase the corresponding args. 53 EraseArgs(caller, erase_indexes, manager); 54 } 55 changes = true; 56 } 57 return changes; 58 } 59 60 private: GetCallers(const FuncGraphPtr & fg)61 static std::vector<CNodePtr> GetCallers(const FuncGraphPtr &fg) { 62 const auto &fg_caller_and_indexes = fg->func_graph_cnodes_index(); 63 std::vector<CNodePtr> caller_cnodes = {}; 64 // Find all caller of fg. 65 for (const auto &it : fg_caller_and_indexes) { 66 const auto &fg_caller_and_index = it.first; 67 auto caller_cnode = fg_caller_and_index->first; 68 auto index = fg_caller_and_index->second; 69 // If index != 0, the caller is a indirect caller, can't erase the parameter of graph.Because 70 // in this situation ValueNode<FuncGraph> is a input of Return or of MakeTuple. 71 if (index != 0) { 72 return {}; 73 } 74 caller_cnodes.push_back(caller_cnode->cast<CNodePtr>()); 75 } 76 return caller_cnodes; 77 } 78 SearchFuncGraphCallers(const FuncGraphPtr & func_graph)79 static std::pair<FuncGraphPtr, std::vector<CNodePtr>> SearchFuncGraphCallers(const FuncGraphPtr &func_graph) { 80 for (const auto &fg : func_graph->func_graphs_used_total()) { 81 if (fg->has_flag(FUNC_GRAPH_FLAG_DEFER_INLINE)) { 82 continue; 83 } 84 const auto ¶meters = fg->parameters(); 85 MS_EXCEPTION_IF_NULL(fg->manager()); 86 const auto &manager_node_users = fg->manager()->node_users(); 87 bool exist_param_unused = 88 std::any_of(parameters.begin(), parameters.end(), [&manager_node_users](const AnfNodePtr ¶meter) { 89 const auto &node_users_it = manager_node_users.find(parameter); 90 return node_users_it == manager_node_users.end() || node_users_it->second.empty(); 91 }); 92 if (exist_param_unused) { 93 const auto &callers = GetCallers(fg); 94 if (!callers.empty()) { 95 return {fg, callers}; 96 } 97 } 98 } 99 return {nullptr, {}}; 100 } 101 EraseUnusedParameters(const FuncGraphPtr & fg,const FuncGraphManagerPtr & manager)102 static std::unordered_set<size_t> EraseUnusedParameters(const FuncGraphPtr &fg, const FuncGraphManagerPtr &manager) { 103 MS_EXCEPTION_IF_NULL(fg->manager()); 104 const auto &manager_node_users = fg->manager()->node_users(); 105 const auto ¶meters = fg->parameters(); 106 std::unordered_set<size_t> unused_parameter_indexes; 107 // Traverse to find all unused parameters. 108 size_t index = 0; 109 for (const auto ¶meter : parameters) { 110 const auto &node_users_it = manager_node_users.find(parameter); 111 if (node_users_it == manager_node_users.end() || node_users_it->second.empty()) { 112 unused_parameter_indexes.insert(index); 113 } 114 index++; 115 } 116 // Erase unused parameters. 117 std::vector<AnfNodePtr> new_parameters; 118 for (size_t i = 0; i < parameters.size(); i++) { 119 if (unused_parameter_indexes.find(i) == unused_parameter_indexes.end()) { 120 new_parameters.push_back(parameters[i]); 121 } else { 122 MS_LOG(DEBUG) << "Erase parameter:" << parameters[i]->DebugString() << ",index:" << i; 123 } 124 } 125 manager->SetParameters(fg, new_parameters); 126 return unused_parameter_indexes; 127 } 128 EraseArgs(const CNodePtr & caller,const std::unordered_set<size_t> & unused_parameter_indexes,const FuncGraphManagerPtr & manager)129 static void EraseArgs(const CNodePtr &caller, const std::unordered_set<size_t> &unused_parameter_indexes, 130 const FuncGraphManagerPtr &manager) { 131 std::vector<AnfNodePtr> new_args = {caller->inputs()[0]}; 132 for (size_t i = 0; i < caller->inputs().size() - 1; i++) { 133 if (unused_parameter_indexes.find(i) == unused_parameter_indexes.end()) { 134 new_args.push_back(caller->inputs()[i + 1]); 135 } else { 136 MS_LOG(DEBUG) << "Erase arg:" << caller->inputs()[i + 1]->DebugString() << ",index:" << i; 137 } 138 } 139 TraceGuard trace_guard(std::make_shared<TraceCopy>(caller->debug_info())); 140 auto new_caller = caller->func_graph()->NewCNode(new_args); 141 new_caller->set_abstract(caller->abstract()); 142 manager->Replace(caller, new_caller); 143 } 144 }; 145 } // namespace irpass 146 } // namespace opt 147 } // namespace mindspore 148 #endif // MINDSPORE_CCSRC_FRONTEND_OPTIMIZER_IRPASS_PARAMETER_ELIMINATE_H 149