1 /**
2 * Copyright 2021 Huawei Technologies Co., Ltd
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 #include "frontend/optimizer/auto_monad_eliminate.h"
18
19 #include <vector>
20 #include <unordered_set>
21 #include <unordered_map>
22 #include <algorithm>
23 #include <memory>
24
25 #include "base/core_ops.h"
26
27 namespace mindspore {
28 namespace opt {
29 using MapParamUserIndexs = std::unordered_map<AnfNodePtr, std::vector<size_t>>;
GenerateLoadGroups(const FuncGraphPtr & fg,const std::vector<AnfNodePtr> & toposet,std::vector<AnfNodePtr> * need_replace_loads,MapParamUserIndexs * unload_users_record,std::vector<size_t> * special_op_indexs)30 std::vector<std::vector<size_t>> GenerateLoadGroups(const FuncGraphPtr &fg, const std::vector<AnfNodePtr> &toposet,
31 std::vector<AnfNodePtr> *need_replace_loads,
32 MapParamUserIndexs *unload_users_record,
33 std::vector<size_t> *special_op_indexs) {
34 std::unordered_map<AnfNodePtr, size_t> load_groups_record;
35 std::vector<std::vector<size_t>> load_groups;
36 for (size_t i = 0; i < toposet.size(); i++) {
37 auto &node = toposet[i];
38 auto cnode = node->cast<CNodePtr>();
39 // Exclude free variable node.
40 if (cnode == nullptr || cnode->func_graph() != fg) {
41 continue;
42 }
43 bool is_special_op = IsPrimitiveCNode(cnode, prim::kPrimCall) || IsValueNode<FuncGraph>(cnode->input(0)) ||
44 IsPrimitiveCNode(cnode, prim::kPrimPartial) || IsPrimitiveCNode(cnode, prim::kPrimSwitch) ||
45 IsPrimitiveCNode(cnode, prim::kPrimSwitchLayer);
46 if (is_special_op) {
47 (void)special_op_indexs->emplace_back(i);
48 }
49
50 // Record param user in toposort nodes.
51 if (!IsPrimitiveCNode(cnode, prim::kPrimLoad)) {
52 for (const auto &input : cnode->inputs()) {
53 AnfNodePtr cur_param = nullptr;
54 if (input->isa<Parameter>()) {
55 cur_param = input;
56 } else if (IsPrimitiveCNode(input, prim::kPrimDepend) && input->cast<CNodePtr>()->input(1)->isa<Parameter>()) {
57 cur_param = input->cast<CNodePtr>()->input(1);
58 }
59 if (cur_param != nullptr) {
60 (void)(*unload_users_record)[cur_param].emplace_back(i);
61 }
62 }
63 continue;
64 }
65
66 auto load_param = cnode->input(1);
67 // first time get same input1 of load.
68 if (load_groups_record.find(load_param) == load_groups_record.end()) {
69 load_groups_record[load_param] = load_groups.size();
70 load_groups.push_back({i});
71 // The first load user of param in toposort, if it can be replace load(param, ud) with load(param, u)
72 // Means there are not nodes which modify param before the load
73 bool can_replace = (*unload_users_record)[load_param].empty() && special_op_indexs->empty();
74 if (can_replace) {
75 need_replace_loads->emplace_back(cnode);
76 }
77 } else {
78 // not first time get same input1 of load
79 load_groups[load_groups_record[load_param]].push_back(i);
80 }
81 }
82 return load_groups;
83 }
84
// Splits a group of Load toposort indexes (ascending) into runs that may be merged.
// Two consecutive loads stay in the same run only if no index from `unload_user_indexs` or `special_op_indexs`
// lies strictly between them. Returns an empty vector when `group` has fewer than two loads (nothing to merge);
// otherwise returns the runs in order, including single-element ones.
std::vector<std::vector<size_t>> SplitGroup(const std::vector<size_t> &group,
                                            const std::vector<size_t> &unload_user_indexs,
                                            const std::vector<size_t> &special_op_indexs) {
  if (group.size() < 2) {
    return {};
  }
  std::vector<std::vector<size_t>> runs;
  std::vector<size_t> current_run{group.front()};
  for (size_t pos = 1; pos < group.size(); ++pos) {
    const size_t lower = group[pos - 1];
    const size_t upper = group[pos];
    auto strictly_between = [lower, upper](size_t index) { return lower < index && index < upper; };
    // A possible writer of the loaded value sits between the two loads: start a new run.
    const bool has_barrier =
      std::any_of(unload_user_indexs.cbegin(), unload_user_indexs.cend(), strictly_between) ||
      std::any_of(special_op_indexs.cbegin(), special_op_indexs.cend(), strictly_between);
    if (has_barrier) {
      runs.push_back(current_run);
      current_run.clear();
    }
    current_run.push_back(upper);
  }
  // Flush the last split group.
  runs.push_back(current_run);
  return runs;
}
118
119 // Pattern1======================================
120 // a = Load(para1, u1)
121 // ...
122 // b = Load(para1, u2)
123 // u3 = UpdateState(u2, b)
124 // ==>
125 // delete the UpdateState
DeleteLoadUserUpdateState(const FuncGraphManagerPtr & manager,const AnfNodePtr & load_user)126 void DeleteLoadUserUpdateState(const FuncGraphManagerPtr &manager, const AnfNodePtr &load_user) {
127 const auto &update_state_cnode = load_user->cast<CNodePtr>();
128 constexpr size_t monad_index = 1;
129 const auto &monad = update_state_cnode->input(monad_index);
130 (void)manager->Replace(load_user, monad);
131 }
132
133 // Pattern2======================================
134 // a = Load(para1, u1)
135 // ...
136 // b = Load(para1, u2)
137 // t = make_tuple(x, b)
138 // u3 = UpdateState(u2, t)
139 //==>
140 // a = Load(para1, u1)
141 // ...
142 // b = Load(para1, u2)
143 // u3 = UpdateState(u2, x)
DeleteLoadUserMakeTuple(const FuncGraphManagerPtr & manager,const CNodePtr & make_tuple,const AnfNodePtr & load)144 void DeleteLoadUserMakeTuple(const FuncGraphManagerPtr &manager, const CNodePtr &make_tuple, const AnfNodePtr &load) {
145 // Initialize the other_input with load in case of all the inputs of the make_tuple is the same load.
146 AnfNodePtr other_input = load;
147 for (size_t i = 1; i < make_tuple->size(); i++) {
148 if (make_tuple->input(i) != load) {
149 other_input = make_tuple->input(i);
150 break;
151 }
152 }
153 MS_EXCEPTION_IF_NULL(other_input);
154 manager->Replace(make_tuple, other_input);
155 }
156
157 // Pattern3======================================
158 // a = Load(para1, u1)
159 // ...
160 // b = Load(para1, u2)
161 // t = make_tuple(x, y, b, z)
162 // u3 = UpdateState(u2, t)
163 //==>
164 // a = Load(para1, u1)
165 // ...
166 // b = Load(para1, u2)
167 // t = make_tuple(x, y, z)
168 // u3 = UpdateState(u2, t)
ReplaceLoadUserMakeTuple(const FuncGraphManagerPtr & manager,const FuncGraphPtr & fg,const CNodePtr & make_tuple,const AnfNodePtr & load)169 void ReplaceLoadUserMakeTuple(const FuncGraphManagerPtr &manager, const FuncGraphPtr &fg, const CNodePtr &make_tuple,
170 const AnfNodePtr &load) {
171 auto &make_tuple_inputs = make_tuple->inputs();
172 std::vector<AnfNodePtr> new_make_tuple_inputs;
173 (void)std::copy_if(make_tuple_inputs.begin(), make_tuple_inputs.end(), std::back_inserter(new_make_tuple_inputs),
174 [load](const AnfNodePtr &input) { return load != input; });
175 const auto &new_make_tuple = fg->NewCNode(new_make_tuple_inputs);
176 // Set abstract for the MakeTuple node.
177 abstract::AbstractBasePtrList element_abstracts;
178 (void)std::transform(new_make_tuple_inputs.begin() + 1, new_make_tuple_inputs.end(),
179 std::back_inserter(element_abstracts),
180 [](const AnfNodePtr &input) { return input->abstract(); });
181 new_make_tuple->set_abstract(std::make_shared<abstract::AbstractTuple>(element_abstracts));
182 manager->Replace(make_tuple, new_make_tuple);
183 }
184
ReplaceLoadUser(const FuncGraphManagerPtr & manager,const FuncGraphPtr & fg,const AnfNodePtr & load)185 bool ReplaceLoadUser(const FuncGraphManagerPtr &manager, const FuncGraphPtr &fg, const AnfNodePtr &load) {
186 bool change = false;
187 auto load_users = manager->node_users()[load];
188 for (const auto &load_user : load_users) {
189 // Pattern1
190 if (IsPrimitiveCNode(load_user.first, prim::kPrimUpdateState)) {
191 DeleteLoadUserUpdateState(manager, load_user.first);
192 change = true;
193 continue;
194 }
195
196 if (IsPrimitiveCNode(load_user.first, prim::kPrimMakeTuple)) {
197 const auto &make_tuple = load_user.first->cast<CNodePtr>();
198 auto &maketuple_users = manager->node_users()[make_tuple];
199 auto maketuple_as_input_of_update =
200 maketuple_users.size() == 1 && IsPrimitiveCNode(maketuple_users.back().first, prim::kPrimUpdateState);
201 if (!maketuple_as_input_of_update) {
202 continue;
203 }
204 // Pattern2
205 if (make_tuple->size() == 3) {
206 DeleteLoadUserMakeTuple(manager, make_tuple, load);
207 change = true;
208 continue;
209 }
210 // Pattern3
211 if (make_tuple->size() > 3) {
212 ReplaceLoadUserMakeTuple(manager, fg, make_tuple, load);
213 change = true;
214 }
215 }
216 }
217 return change;
218 }
219
ReplaceSameGroupLoad(const FuncGraphManagerPtr & manager,const FuncGraphPtr & fg,const std::vector<AnfNodePtr> & toposet,const std::vector<size_t> & group)220 bool ReplaceSameGroupLoad(const FuncGraphManagerPtr &manager, const FuncGraphPtr &fg,
221 const std::vector<AnfNodePtr> &toposet, const std::vector<size_t> &group) {
222 if (group.size() <= 1) {
223 return false;
224 }
225 bool change = false;
226 const auto &main = toposet[group[0]];
227 for (size_t i = 1; i < group.size(); i++) {
228 change = ReplaceLoadUser(manager, fg, toposet[group[i]]);
229 manager->Replace(toposet[group[i]], main);
230 }
231 return change;
232 }
233
GetFirstMonad(const FuncGraphPtr & fg)234 AnfNodePtr GetFirstMonad(const FuncGraphPtr &fg) {
235 auto ¶ms = fg->parameters();
236 auto end = (params.size() > 1) ? (params.rbegin() + 2) : params.rend();
237 auto iter = std::find_if(params.rbegin(), end, [](const AnfNodePtr ¶) { return HasAbstractUMonad(para); });
238 if (iter != end) {
239 return *iter;
240 }
241 auto monad = NewValueNode(kUMonad);
242 monad->set_abstract(kUMonad->ToAbstract());
243 return monad;
244 }
245
246 // Replace UpdateStates with U for first load.
247 // Covert:
248 // u1 = UpdateState(u, c)
249 // p1 = Load(para1, u1) // first load for para1
250 // To:
251 // u1 = UpdateState(u, c)
252 // p1 = Load(para1, u') // u' is first monad in graph or new monad
ReplaceUpdateStateForLoad(const FuncGraphPtr & fg,const std::vector<AnfNodePtr> & need_replace_loads)253 bool ReplaceUpdateStateForLoad(const FuncGraphPtr &fg, const std::vector<AnfNodePtr> &need_replace_loads) {
254 if (need_replace_loads.size() == 0) {
255 return false;
256 }
257 bool change = false;
258 constexpr size_t second_input_index = 2;
259 auto monad = GetFirstMonad(fg);
260 for (const auto &load_node : need_replace_loads) {
261 if (!IsPrimitiveCNode(load_node, prim::kPrimLoad)) {
262 continue;
263 }
264 auto update_state = load_node->cast<CNodePtr>()->input(second_input_index);
265 auto mgr = fg->manager();
266 MS_EXCEPTION_IF_NULL(mgr);
267 // If the u1 only used by Load and one other updatestate, no need to replace u1 by u'.
268 auto &node_users = mgr->node_users()[update_state];
269 constexpr size_t kUserSize = 2;
270 if (!IsPrimitiveCNode(update_state, prim::kPrimUpdateState) || node_users.size() == kUserSize) {
271 continue;
272 }
273 mgr->SetEdge(load_node, second_input_index, monad);
274 change = true;
275 }
276 return change;
277 }
278
279 // Node1{primLoad,X,Y1},...,Node{Node's input != X},...,Node2{primLoad,X,Y2},... =>
280 // Node1{primLoad,X,Y1},...,Node{Nodes' input != X},...,Node1,...
ReplaceAutoMonadNode(const FuncGraphManagerPtr & manager) const281 bool AutoMonadEliminator::ReplaceAutoMonadNode(const FuncGraphManagerPtr &manager) const {
282 auto changed = false;
283 for (const FuncGraphPtr &fg : manager->func_graphs()) {
284 std::vector<AnfNodePtr> toposet = TopoSort(fg->get_return());
285 // Record the set of the first load of param which no nodes modify param before the load in toposort.
286 std::vector<AnfNodePtr> need_replace_loads;
287 // Record the param and the toposort id of the unload user of param, they may modify the value of param.
288 MapParamUserIndexs unload_users_record;
289 // Record the toposort id of special_op(call, partial, switch, switch_layer), they may modify the value of param.
290 std::vector<size_t> special_op_indexs;
291 std::vector<std::vector<size_t>> load_groups =
292 GenerateLoadGroups(fg, toposet, &need_replace_loads, &unload_users_record, &special_op_indexs);
293 // split group if there is no-load node between two load nodes.
294 std::vector<std::vector<size_t>> need_merge_loads;
295 for (auto &group : load_groups) {
296 auto load_param = toposet[group.back()]->cast<CNodePtr>()->input(1);
297 const auto &unload_user_indexs = unload_users_record[load_param];
298 auto groups = SplitGroup(group, unload_user_indexs, special_op_indexs);
299 need_merge_loads.insert(need_merge_loads.end(), groups.begin(), groups.end());
300 }
301 for (auto &group : need_merge_loads) {
302 bool replaced = ReplaceSameGroupLoad(manager, fg, toposet, group);
303 if (replaced) {
304 changed = true;
305 }
306 }
307 bool update_state_replaced = ReplaceUpdateStateForLoad(fg, need_replace_loads);
308 if (update_state_replaced) {
309 changed = true;
310 }
311 }
312 return changed;
313 }
314
315 // Eliminate auto monad node:
316 // From:
317 // u1 = UpdateState(...);
318 // xxx = User(u1); // Other users except below Depend.
319 // output = Depend(output, u1);
320 // return output;
321 // To:
322 // u1 = UpdateState(...);
323 // xxx = User(u1);
324 // return output;
EliminateAutoMonadNode(const FuncGraphManagerPtr & manager) const325 bool AutoMonadEliminator::EliminateAutoMonadNode(const FuncGraphManagerPtr &manager) const {
326 auto changed = false;
327 for (const FuncGraphPtr &fg : manager->func_graphs()) {
328 auto output = fg->output();
329 if (output == nullptr) {
330 continue;
331 }
332 if (!IsPrimitiveCNode(output, prim::kPrimDepend)) {
333 continue;
334 }
335 constexpr size_t attach_index = 2;
336 auto attach = output->cast<CNodePtr>()->input(attach_index);
337 if (!IsPrimitiveCNode(attach, prim::kPrimUpdateState)) {
338 continue;
339 }
340 auto &node_users = manager->node_users();
341 auto iter = node_users.find(attach);
342 if (iter == node_users.end()) {
343 MS_LOG(EXCEPTION) << "No user of node: " << attach->DebugString();
344 }
345 auto &users = iter->second;
346 if (users.size() <= 1) {
347 continue;
348 }
349 constexpr size_t input_index = 1;
350 auto input = output->cast<CNodePtr>()->input(input_index);
351 MS_LOG(DEBUG) << "Change " << output->DebugString() << " -> " << input->DebugString();
352 fg->set_output(input);
353 changed = true;
354 }
355 MS_LOG(DEBUG) << "Changed: " << changed;
356 return changed;
357 }
358 } // namespace opt
359 } // namespace mindspore
360