• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2021 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #include "frontend/optimizer/auto_monad_eliminate.h"
18 
19 #include <vector>
20 #include <unordered_set>
21 #include <unordered_map>
22 #include <algorithm>
23 #include <memory>
24 
25 #include "base/core_ops.h"
26 
27 namespace mindspore {
28 namespace opt {
// Maps a node to a list of toposort indexes. In this file it is keyed by a parameter node and
// holds the indexes of the non-Load CNodes that take that parameter as an input.
29 using MapParamUserIndexs = std::unordered_map<AnfNodePtr, std::vector<size_t>>;
GenerateLoadGroups(const FuncGraphPtr & fg,const std::vector<AnfNodePtr> & toposet,std::vector<AnfNodePtr> * need_replace_loads,MapParamUserIndexs * unload_users_record,std::vector<size_t> * special_op_indexs)30 std::vector<std::vector<size_t>> GenerateLoadGroups(const FuncGraphPtr &fg, const std::vector<AnfNodePtr> &toposet,
31                                                     std::vector<AnfNodePtr> *need_replace_loads,
32                                                     MapParamUserIndexs *unload_users_record,
33                                                     std::vector<size_t> *special_op_indexs) {
34   std::unordered_map<AnfNodePtr, size_t> load_groups_record;
35   std::vector<std::vector<size_t>> load_groups;
36   for (size_t i = 0; i < toposet.size(); i++) {
37     auto &node = toposet[i];
38     auto cnode = node->cast<CNodePtr>();
39     // Exclude free variable node.
40     if (cnode == nullptr || cnode->func_graph() != fg) {
41       continue;
42     }
43     bool is_special_op = IsPrimitiveCNode(cnode, prim::kPrimCall) || IsValueNode<FuncGraph>(cnode->input(0)) ||
44                          IsPrimitiveCNode(cnode, prim::kPrimPartial) || IsPrimitiveCNode(cnode, prim::kPrimSwitch) ||
45                          IsPrimitiveCNode(cnode, prim::kPrimSwitchLayer);
46     if (is_special_op) {
47       (void)special_op_indexs->emplace_back(i);
48     }
49 
50     // Record param user in toposort nodes.
51     if (!IsPrimitiveCNode(cnode, prim::kPrimLoad)) {
52       for (const auto &input : cnode->inputs()) {
53         AnfNodePtr cur_param = nullptr;
54         if (input->isa<Parameter>()) {
55           cur_param = input;
56         } else if (IsPrimitiveCNode(input, prim::kPrimDepend) && input->cast<CNodePtr>()->input(1)->isa<Parameter>()) {
57           cur_param = input->cast<CNodePtr>()->input(1);
58         }
59         if (cur_param != nullptr) {
60           (void)(*unload_users_record)[cur_param].emplace_back(i);
61         }
62       }
63       continue;
64     }
65 
66     auto load_param = cnode->input(1);
67     // first time get same input1 of load.
68     if (load_groups_record.find(load_param) == load_groups_record.end()) {
69       load_groups_record[load_param] = load_groups.size();
70       load_groups.push_back({i});
71       // The first load user of param in toposort, if it can be replace load(param, ud) with load(param, u)
72       // Means there are not nodes which modify param before the load
73       bool can_replace = (*unload_users_record)[load_param].empty() && special_op_indexs->empty();
74       if (can_replace) {
75         need_replace_loads->emplace_back(cnode);
76       }
77     } else {
78       // not first time get same input1 of load
79       load_groups[load_groups_record[load_param]].push_back(i);
80     }
81   }
82   return load_groups;
83 }
84 
// Splits one group of Load indexes into runs of Loads that have nothing in between.
//
// group:              toposort indexes of the Loads of one parameter, in topo order.
// unload_user_indexs: toposort indexes of the non-Load users of that parameter.
// special_op_indexs:  toposort indexes of the special ops in the graph.
//
// Two neighbouring Loads stay in the same run only when no index from either list lies strictly
// between them. Returns an empty result when the group has fewer than two Loads; otherwise every
// Load appears in exactly one run (a run may contain a single Load).
std::vector<std::vector<size_t>> SplitGroup(const std::vector<size_t> &group,
                                            const std::vector<size_t> &unload_user_indexs,
                                            const std::vector<size_t> &special_op_indexs) {
  if (group.size() <= 1) {
    return {};
  }

  // True when some index in `indexs` is strictly inside the open interval (lower, upper).
  const auto has_index_between = [](const std::vector<size_t> &indexs, size_t lower, size_t upper) {
    for (const auto index : indexs) {
      if (index > lower && index < upper) {
        return true;
      }
    }
    return false;
  };

  std::vector<std::vector<size_t>> runs;
  std::vector<size_t> current_run{group.front()};
  for (size_t pos = 1; pos < group.size(); ++pos) {
    const size_t prev_load = group[pos - 1];
    const size_t cur_load = group[pos];
    const bool interrupted = has_index_between(unload_user_indexs, prev_load, cur_load) ||
                             has_index_between(special_op_indexs, prev_load, cur_load);
    if (interrupted) {
      // Something may have touched the parameter: close the current run and start a new one.
      runs.push_back(current_run);
      current_run.clear();
    }
    current_run.push_back(cur_load);
  }
  // The run that is still open at the end is always emitted.
  runs.push_back(current_run);
  return runs;
}
118 
119 // Pattern1======================================
120 // a = Load(para1, u1)
121 // ...
122 // b = Load(para1, u2)
123 // u3 = UpdateState(u2, b)
124 // ==>
125 // delete the UpdateState
DeleteLoadUserUpdateState(const FuncGraphManagerPtr & manager,const AnfNodePtr & load_user)126 void DeleteLoadUserUpdateState(const FuncGraphManagerPtr &manager, const AnfNodePtr &load_user) {
127   const auto &update_state_cnode = load_user->cast<CNodePtr>();
128   constexpr size_t monad_index = 1;
129   const auto &monad = update_state_cnode->input(monad_index);
130   (void)manager->Replace(load_user, monad);
131 }
132 
133 // Pattern2======================================
134 // a = Load(para1, u1)
135 // ...
136 // b = Load(para1, u2)
137 // t = make_tuple(x, b)
138 // u3 = UpdateState(u2, t)
139 //==>
140 // a = Load(para1, u1)
141 // ...
142 // b = Load(para1, u2)
143 // u3 = UpdateState(u2, x)
DeleteLoadUserMakeTuple(const FuncGraphManagerPtr & manager,const CNodePtr & make_tuple,const AnfNodePtr & load)144 void DeleteLoadUserMakeTuple(const FuncGraphManagerPtr &manager, const CNodePtr &make_tuple, const AnfNodePtr &load) {
145   // Initialize the other_input with load in case of all the inputs of the make_tuple is the same load.
146   AnfNodePtr other_input = load;
147   for (size_t i = 1; i < make_tuple->size(); i++) {
148     if (make_tuple->input(i) != load) {
149       other_input = make_tuple->input(i);
150       break;
151     }
152   }
153   MS_EXCEPTION_IF_NULL(other_input);
154   manager->Replace(make_tuple, other_input);
155 }
156 
157 // Pattern3======================================
158 // a = Load(para1, u1)
159 // ...
160 // b = Load(para1, u2)
161 // t = make_tuple(x, y, b, z)
162 // u3 = UpdateState(u2, t)
163 //==>
164 // a = Load(para1, u1)
165 // ...
166 // b = Load(para1, u2)
167 // t = make_tuple(x, y, z)
168 // u3 = UpdateState(u2, t)
ReplaceLoadUserMakeTuple(const FuncGraphManagerPtr & manager,const FuncGraphPtr & fg,const CNodePtr & make_tuple,const AnfNodePtr & load)169 void ReplaceLoadUserMakeTuple(const FuncGraphManagerPtr &manager, const FuncGraphPtr &fg, const CNodePtr &make_tuple,
170                               const AnfNodePtr &load) {
171   auto &make_tuple_inputs = make_tuple->inputs();
172   std::vector<AnfNodePtr> new_make_tuple_inputs;
173   (void)std::copy_if(make_tuple_inputs.begin(), make_tuple_inputs.end(), std::back_inserter(new_make_tuple_inputs),
174                      [load](const AnfNodePtr &input) { return load != input; });
175   const auto &new_make_tuple = fg->NewCNode(new_make_tuple_inputs);
176   // Set abstract for the MakeTuple node.
177   abstract::AbstractBasePtrList element_abstracts;
178   (void)std::transform(new_make_tuple_inputs.begin() + 1, new_make_tuple_inputs.end(),
179                        std::back_inserter(element_abstracts),
180                        [](const AnfNodePtr &input) { return input->abstract(); });
181   new_make_tuple->set_abstract(std::make_shared<abstract::AbstractTuple>(element_abstracts));
182   manager->Replace(make_tuple, new_make_tuple);
183 }
184 
ReplaceLoadUser(const FuncGraphManagerPtr & manager,const FuncGraphPtr & fg,const AnfNodePtr & load)185 bool ReplaceLoadUser(const FuncGraphManagerPtr &manager, const FuncGraphPtr &fg, const AnfNodePtr &load) {
186   bool change = false;
187   auto load_users = manager->node_users()[load];
188   for (const auto &load_user : load_users) {
189     // Pattern1
190     if (IsPrimitiveCNode(load_user.first, prim::kPrimUpdateState)) {
191       DeleteLoadUserUpdateState(manager, load_user.first);
192       change = true;
193       continue;
194     }
195 
196     if (IsPrimitiveCNode(load_user.first, prim::kPrimMakeTuple)) {
197       const auto &make_tuple = load_user.first->cast<CNodePtr>();
198       auto &maketuple_users = manager->node_users()[make_tuple];
199       auto maketuple_as_input_of_update =
200         maketuple_users.size() == 1 && IsPrimitiveCNode(maketuple_users.back().first, prim::kPrimUpdateState);
201       if (!maketuple_as_input_of_update) {
202         continue;
203       }
204       // Pattern2
205       if (make_tuple->size() == 3) {
206         DeleteLoadUserMakeTuple(manager, make_tuple, load);
207         change = true;
208         continue;
209       }
210       // Pattern3
211       if (make_tuple->size() > 3) {
212         ReplaceLoadUserMakeTuple(manager, fg, make_tuple, load);
213         change = true;
214       }
215     }
216   }
217   return change;
218 }
219 
ReplaceSameGroupLoad(const FuncGraphManagerPtr & manager,const FuncGraphPtr & fg,const std::vector<AnfNodePtr> & toposet,const std::vector<size_t> & group)220 bool ReplaceSameGroupLoad(const FuncGraphManagerPtr &manager, const FuncGraphPtr &fg,
221                           const std::vector<AnfNodePtr> &toposet, const std::vector<size_t> &group) {
222   if (group.size() <= 1) {
223     return false;
224   }
225   bool change = false;
226   const auto &main = toposet[group[0]];
227   for (size_t i = 1; i < group.size(); i++) {
228     change = ReplaceLoadUser(manager, fg, toposet[group[i]]);
229     manager->Replace(toposet[group[i]], main);
230   }
231   return change;
232 }
233 
GetFirstMonad(const FuncGraphPtr & fg)234 AnfNodePtr GetFirstMonad(const FuncGraphPtr &fg) {
235   auto &params = fg->parameters();
236   auto end = (params.size() > 1) ? (params.rbegin() + 2) : params.rend();
237   auto iter = std::find_if(params.rbegin(), end, [](const AnfNodePtr &para) { return HasAbstractUMonad(para); });
238   if (iter != end) {
239     return *iter;
240   }
241   auto monad = NewValueNode(kUMonad);
242   monad->set_abstract(kUMonad->ToAbstract());
243   return monad;
244 }
245 
246 // Replace UpdateStates with U for first load.
247 // Covert:
248 // u1 = UpdateState(u, c)
249 // p1 = Load(para1, u1)  // first load for para1
250 // To:
251 // u1 = UpdateState(u, c)
252 // p1 = Load(para1, u')  // u' is first monad in graph or new monad
ReplaceUpdateStateForLoad(const FuncGraphPtr & fg,const std::vector<AnfNodePtr> & need_replace_loads)253 bool ReplaceUpdateStateForLoad(const FuncGraphPtr &fg, const std::vector<AnfNodePtr> &need_replace_loads) {
254   if (need_replace_loads.size() == 0) {
255     return false;
256   }
257   bool change = false;
258   constexpr size_t second_input_index = 2;
259   auto monad = GetFirstMonad(fg);
260   for (const auto &load_node : need_replace_loads) {
261     if (!IsPrimitiveCNode(load_node, prim::kPrimLoad)) {
262       continue;
263     }
264     auto update_state = load_node->cast<CNodePtr>()->input(second_input_index);
265     auto mgr = fg->manager();
266     MS_EXCEPTION_IF_NULL(mgr);
267     // If the u1 only used by Load and one other updatestate, no need to replace u1 by u'.
268     auto &node_users = mgr->node_users()[update_state];
269     constexpr size_t kUserSize = 2;
270     if (!IsPrimitiveCNode(update_state, prim::kPrimUpdateState) || node_users.size() == kUserSize) {
271       continue;
272     }
273     mgr->SetEdge(load_node, second_input_index, monad);
274     change = true;
275   }
276   return change;
277 }
278 
279 // Node1{primLoad,X,Y1},...,Node{Node's input != X},...,Node2{primLoad,X,Y2},... =>
280 // Node1{primLoad,X,Y1},...,Node{Nodes' input != X},...,Node1,...
ReplaceAutoMonadNode(const FuncGraphManagerPtr & manager) const281 bool AutoMonadEliminator::ReplaceAutoMonadNode(const FuncGraphManagerPtr &manager) const {
282   auto changed = false;
283   for (const FuncGraphPtr &fg : manager->func_graphs()) {
284     std::vector<AnfNodePtr> toposet = TopoSort(fg->get_return());
285     // Record the set of the first load of param which no nodes modify param before the load in toposort.
286     std::vector<AnfNodePtr> need_replace_loads;
287     // Record the param and the toposort id of the unload user of param, they may modify the value of param.
288     MapParamUserIndexs unload_users_record;
289     // Record the toposort id of special_op(call, partial, switch, switch_layer), they may modify the value of param.
290     std::vector<size_t> special_op_indexs;
291     std::vector<std::vector<size_t>> load_groups =
292       GenerateLoadGroups(fg, toposet, &need_replace_loads, &unload_users_record, &special_op_indexs);
293     // split group if there is no-load node between two load nodes.
294     std::vector<std::vector<size_t>> need_merge_loads;
295     for (auto &group : load_groups) {
296       auto load_param = toposet[group.back()]->cast<CNodePtr>()->input(1);
297       const auto &unload_user_indexs = unload_users_record[load_param];
298       auto groups = SplitGroup(group, unload_user_indexs, special_op_indexs);
299       need_merge_loads.insert(need_merge_loads.end(), groups.begin(), groups.end());
300     }
301     for (auto &group : need_merge_loads) {
302       bool replaced = ReplaceSameGroupLoad(manager, fg, toposet, group);
303       if (replaced) {
304         changed = true;
305       }
306     }
307     bool update_state_replaced = ReplaceUpdateStateForLoad(fg, need_replace_loads);
308     if (update_state_replaced) {
309       changed = true;
310     }
311   }
312   return changed;
313 }
314 
315 // Eliminate auto monad node:
316 // From:
317 //    u1 = UpdateState(...);
318 //    xxx = User(u1); // Other users except below Depend.
319 //    output = Depend(output, u1);
320 //    return output;
321 // To:
322 //    u1 = UpdateState(...);
323 //    xxx = User(u1);
324 //    return output;
EliminateAutoMonadNode(const FuncGraphManagerPtr & manager) const325 bool AutoMonadEliminator::EliminateAutoMonadNode(const FuncGraphManagerPtr &manager) const {
326   auto changed = false;
327   for (const FuncGraphPtr &fg : manager->func_graphs()) {
328     auto output = fg->output();
329     if (output == nullptr) {
330       continue;
331     }
332     if (!IsPrimitiveCNode(output, prim::kPrimDepend)) {
333       continue;
334     }
335     constexpr size_t attach_index = 2;
336     auto attach = output->cast<CNodePtr>()->input(attach_index);
337     if (!IsPrimitiveCNode(attach, prim::kPrimUpdateState)) {
338       continue;
339     }
340     auto &node_users = manager->node_users();
341     auto iter = node_users.find(attach);
342     if (iter == node_users.end()) {
343       MS_LOG(EXCEPTION) << "No user of node: " << attach->DebugString();
344     }
345     auto &users = iter->second;
346     if (users.size() <= 1) {
347       continue;
348     }
349     constexpr size_t input_index = 1;
350     auto input = output->cast<CNodePtr>()->input(input_index);
351     MS_LOG(DEBUG) << "Change " << output->DebugString() << " -> " << input->DebugString();
352     fg->set_output(input);
353     changed = true;
354   }
355   MS_LOG(DEBUG) << "Changed: " << changed;
356   return changed;
357 }
358 }  // namespace opt
359 }  // namespace mindspore
360