1 /**
2 * Copyright 2021-2022 Huawei Technologies Co., Ltd
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 #include "frontend/optimizer/auto_monad_eliminate.h"
18
19 #include <algorithm>
20 #include <memory>
21 #include <string>
22 #include <optional>
23 #include <map>
24 #include <utility>
25 #include <vector>
26
27 #include "mindspore/core/ops/sequence_ops.h"
28 #include "mindspore/core/ops/framework_ops.h"
29 #include "utils/hash_map.h"
30 #include "utils/ordered_map.h"
31 #include "abstract/abstract_value.h"
32
33 namespace mindspore {
34 namespace opt {
35 namespace {
// Ref key of a parameter -> toposort indexes of the nodes recorded as users of a ref with that key
// (side-effect or Depend nodes taking such a ref as an input; filled by GenerateLoadGroups).
using ParamUserMap = mindspore::HashMap<std::string, std::vector<size_t>>;
// Ref key of a parameter -> toposort indexes of the Load nodes reading it. NOTE(review): an ordered map is
// presumably used so groups are visited in insertion order, keeping the pass deterministic — confirm.
using LoadGraphMap = OrderedMap<std::string, std::vector<size_t>>;
38
GetRefKey(const AnfNodePtr & node)39 std::optional<std::string> GetRefKey(const AnfNodePtr &node) {
40 auto abs = node->abstract();
41 if (abs == nullptr) {
42 // Abstract for some Depends node are not proper set, we follow its input.
43 if (IsPrimitiveCNode(node, prim::kPrimDepend)) {
44 return GetRefKey(node->cast<CNodePtr>()->input(1));
45 }
46 // Abstract should be set except UpdateState nodes.
47 if (!IsPrimitiveCNode(node, prim::kPrimUpdateState)) {
48 MS_LOG(WARNING) << "Abstract not set for " << node->DebugString();
49 }
50 return std::nullopt;
51 }
52 auto abs_ref = abs->cast<abstract::AbstractRefPtr>();
53 if (abs_ref == nullptr) {
54 return std::nullopt;
55 }
56 auto ref_key = abs_ref->ref_key_value()->cast<StringImmPtr>();
57 if (ref_key == nullptr) {
58 return std::nullopt;
59 }
60 return ref_key->value();
61 }
62
HasSideEffect(const CNodePtr & cnode)63 bool HasSideEffect(const CNodePtr &cnode) {
64 const auto &inputs = cnode->inputs();
65 constexpr size_t kRequiredArgs = 2;
66 if (inputs.size() > kRequiredArgs) {
67 return HasAbstractMonad(inputs.back());
68 }
69 return false;
70 }
71
IsSpecialNode(const CNodePtr & cnode)72 bool IsSpecialNode(const CNodePtr &cnode) {
73 const auto &first_input = cnode->input(0);
74 return IsPrimitiveCNode(first_input, prim::kPrimJ) || IsPrimitiveCNode(first_input, prim::kPrimVmap) ||
75 IsPrimitiveCNode(first_input, prim::kPrimTaylor) || IsPrimitiveCNode(first_input, prim::kPrimShard) ||
76 IsValueNode<FuncGraph>(first_input) || cnode->IsApply(prim::kPrimCall) || cnode->IsApply(prim::kPrimPartial) ||
77 cnode->IsApply(prim::kPrimSwitch) || cnode->IsApply(prim::kPrimSwitchLayer);
78 }
79
GenerateLoadGroups(const FuncGraphPtr & fg,std::vector<AnfNodePtr> * toposet,std::vector<AnfNodePtr> * need_replace_loads,ParamUserMap * param_users,std::vector<size_t> * special_op_indexes)80 LoadGraphMap GenerateLoadGroups(const FuncGraphPtr &fg, std::vector<AnfNodePtr> *toposet,
81 std::vector<AnfNodePtr> *need_replace_loads, ParamUserMap *param_users,
82 std::vector<size_t> *special_op_indexes) {
83 LoadGraphMap load_groups;
84 // Record inputs of load and id of load in toposort.
85 // RefKey --> (Monad --> index).
86 std::map<std::string, std::map<AnfNodePtr, size_t>> param_monads;
87 auto mgr = fg->manager();
88 MS_EXCEPTION_IF_NULL(mgr);
89 for (size_t i = 0; i < toposet->size(); i++) {
90 auto cnode = dyn_cast<CNode>((*toposet)[i]);
91 // Exclude free variable node.
92 if (cnode == nullptr || cnode->func_graph() != fg) {
93 continue;
94 }
95 // Handle Load node.
96 if (cnode->IsApply(prim::kPrimLoad)) {
97 auto ref_key = GetRefKey(cnode->input(1));
98 if (!ref_key.has_value()) {
99 MS_LOG(INFO) << "Load without ref key: " << cnode->DebugString();
100 continue;
101 }
102 // Group load nodes by their input ref key.
103 auto &group = load_groups[ref_key.value()];
104 constexpr size_t monad_index = 2;
105 auto monad = cnode->input(monad_index);
106 std::map<AnfNodePtr, size_t> &cur_param_monads = param_monads[ref_key.value()];
107 const auto &iter = cur_param_monads.find(monad);
108 // Remove duplicate load which has the same inputs, otherwise there may be an error in the load grouping.
109 if (iter != cur_param_monads.end()) {
110 auto id = iter->second;
111 auto &first_load = (*toposet)[id];
112 (void)mgr->Replace(cnode, first_load);
113 (*toposet)[i] = first_load;
114 continue;
115 } else {
116 cur_param_monads[monad] = i;
117 (void)group.emplace_back(i);
118 }
119 if (group.size() == 1) {
120 // The first load user of param in toposort, if it can be replace load(param, ud) with load(param, u),
121 // Means there are not nodes which modify param before the load.
122 const bool param_not_used = (param_users->find(ref_key.value()) == param_users->end());
123 const bool can_replace = (param_not_used && special_op_indexes->empty());
124 if (can_replace) {
125 (void)need_replace_loads->emplace_back(cnode);
126 }
127 }
128 continue;
129 }
130 // Record special cnode.
131 if (IsSpecialNode(cnode)) {
132 (void)special_op_indexes->emplace_back(i);
133 continue;
134 }
135 // Record param user in toposort nodes.
136 // We only check side effect cnodes or Depend nodes.
137 if (HasSideEffect(cnode) || cnode->IsApply(prim::kPrimDepend)) {
138 for (size_t n = 1; n < cnode->size(); ++n) {
139 const auto &input = cnode->input(n);
140 auto ref_key = GetRefKey(input);
141 if (ref_key.has_value()) {
142 (void)(*param_users)[ref_key.value()].emplace_back(i);
143 }
144 }
145 }
146 }
147 return load_groups;
148 }
149
// Returns true if some element of `indexes` lies strictly between `first` and `second` (both ends
// excluded). An empty `indexes`, or first >= second, yields false.
bool HasIndexBetween(const std::vector<size_t> &indexes, size_t first, size_t second) {
  const auto strictly_inside = [first, second](size_t index) { return first < index && index < second; };
  return std::find_if(indexes.cbegin(), indexes.cend(), strictly_inside) != indexes.cend();
}
154
SplitGroup(const std::vector<size_t> & group,const std::vector<size_t> & param_user_indexes,const std::vector<size_t> & special_op_indexes)155 std::vector<std::vector<size_t>> SplitGroup(const std::vector<size_t> &group,
156 const std::vector<size_t> ¶m_user_indexes,
157 const std::vector<size_t> &special_op_indexes) {
158 if (group.size() <= 1) {
159 return {};
160 }
161 size_t cur_load_index = 1;
162 size_t pre_load_index = 0;
163 std::vector<size_t> cur_group = {group[pre_load_index]};
164 std::vector<std::vector<size_t>> split_groups;
165 while (cur_load_index < group.size()) {
166 const auto cur_load = group[cur_load_index];
167 const auto prev_load = group[pre_load_index];
168 // Exist node which is the user of load_param between prev_load and cur_load,
169 // Do not divide into the same group.
170 if (HasIndexBetween(param_user_indexes, prev_load, cur_load) ||
171 HasIndexBetween(special_op_indexes, prev_load, cur_load)) {
172 (void)split_groups.emplace_back(std::move(cur_group));
173 }
174 cur_group.push_back(cur_load);
175 pre_load_index++;
176 cur_load_index++;
177 }
178 // push back the last splited group.
179 split_groups.push_back(cur_group);
180 return split_groups;
181 }
182
183 // Pattern1======================================
184 // a = Load(para1, u1)
185 // ...
186 // b = Load(para1, u2)
187 // u3 = UpdateState(u2, b)
188 // ==>
189 // delete the UpdateState
DeleteLoadUserUpdateState(const FuncGraphManagerPtr & manager,const AnfNodePtr & load_user)190 void DeleteLoadUserUpdateState(const FuncGraphManagerPtr &manager, const AnfNodePtr &load_user) {
191 const auto &update_state_cnode = load_user->cast<CNodePtr>();
192 constexpr size_t monad_index = 1;
193 const auto &monad = update_state_cnode->input(monad_index);
194 (void)manager->Replace(load_user, monad);
195 }
196
197 // Pattern2======================================
198 // a = Load(para1, u1)
199 // ...
200 // b = Load(para1, u2)
201 // t = make_tuple(x, b)
202 // u3 = UpdateState(u2, t)
203 // ==>
204 // a = Load(para1, u1)
205 // ...
206 // b = Load(para1, u2)
207 // u3 = UpdateState(u2, x)
DeleteLoadUserMakeTuple(const FuncGraphManagerPtr & manager,const CNodePtr & make_tuple,const AnfNodePtr & load)208 void DeleteLoadUserMakeTuple(const FuncGraphManagerPtr &manager, const CNodePtr &make_tuple, const AnfNodePtr &load) {
209 // Initialize the other_input with load in case of all the inputs of the make_tuple is the same load.
210 AnfNodePtr other_input = load;
211 for (size_t i = 1; i < make_tuple->size(); i++) {
212 if (make_tuple->input(i) != load) {
213 other_input = make_tuple->input(i);
214 break;
215 }
216 }
217 MS_EXCEPTION_IF_NULL(other_input);
218 (void)manager->Replace(make_tuple, other_input);
219 }
220
221 // Pattern3======================================
222 // a = Load(para1, u1)
223 // ...
224 // b = Load(para1, u2)
225 // t = make_tuple(x, y, b, z)
226 // u3 = UpdateState(u2, t)
227 // ==>
228 // a = Load(para1, u1)
229 // ...
230 // b = Load(para1, u2)
231 // t = make_tuple(x, y, z)
232 // u3 = UpdateState(u2, t)
ReplaceLoadUserMakeTuple(const FuncGraphManagerPtr & manager,const CNodePtr & make_tuple,const AnfNodePtr & load)233 void ReplaceLoadUserMakeTuple(const FuncGraphManagerPtr &manager, const CNodePtr &make_tuple, const AnfNodePtr &load) {
234 auto &make_tuple_inputs = make_tuple->inputs();
235 std::vector<AnfNodePtr> new_make_tuple_inputs;
236 (void)std::copy_if(make_tuple_inputs.begin(), make_tuple_inputs.end(), std::back_inserter(new_make_tuple_inputs),
237 [load](const AnfNodePtr &input) { return load != input; });
238 auto fg = make_tuple->func_graph();
239 MS_EXCEPTION_IF_NULL(fg);
240 const auto &new_make_tuple = fg->NewCNode(new_make_tuple_inputs);
241 // Set abstract for the MakeTuple node.
242 abstract::AbstractBasePtrList element_abstracts;
243 (void)std::transform(new_make_tuple_inputs.begin() + 1, new_make_tuple_inputs.end(),
244 std::back_inserter(element_abstracts),
245 [](const AnfNodePtr &input) { return input->abstract(); });
246 new_make_tuple->set_abstract(std::make_shared<abstract::AbstractTuple>(element_abstracts));
247 (void)manager->Replace(make_tuple, new_make_tuple);
248 }
249
ReplaceLoadUser(const FuncGraphManagerPtr & manager,const AnfNodePtr & load)250 bool ReplaceLoadUser(const FuncGraphManagerPtr &manager, const AnfNodePtr &load) {
251 bool change = false;
252 auto load_users = manager->node_users()[load];
253 for (const auto &load_user : load_users) {
254 // Pattern1
255 if (IsPrimitiveCNode(load_user.first, prim::kPrimUpdateState)) {
256 DeleteLoadUserUpdateState(manager, load_user.first);
257 change = true;
258 continue;
259 }
260
261 if (IsPrimitiveCNode(load_user.first, prim::kPrimMakeTuple)) {
262 const auto &make_tuple = load_user.first->cast<CNodePtr>();
263 auto &maketuple_users = manager->node_users()[make_tuple];
264 auto maketuple_as_input_of_update =
265 maketuple_users.size() == 1 && IsPrimitiveCNode(maketuple_users.back().first, prim::kPrimUpdateState);
266 if (!maketuple_as_input_of_update) {
267 continue;
268 }
269 // Pattern2
270 if (make_tuple->size() == 3) {
271 DeleteLoadUserMakeTuple(manager, make_tuple, load);
272 change = true;
273 continue;
274 }
275 // Pattern3
276 if (make_tuple->size() > 3) {
277 ReplaceLoadUserMakeTuple(manager, make_tuple, load);
278 change = true;
279 }
280 }
281 }
282 return change;
283 }
284
ReplaceSameGroupLoad(const FuncGraphManagerPtr & manager,const std::vector<AnfNodePtr> & toposet,const std::vector<size_t> & group)285 bool ReplaceSameGroupLoad(const FuncGraphManagerPtr &manager, const std::vector<AnfNodePtr> &toposet,
286 const std::vector<size_t> &group) {
287 if (group.size() <= 1) {
288 return false;
289 }
290 bool change = false;
291 const auto &main = toposet[group[0]];
292 for (size_t i = 1; i < group.size(); i++) {
293 change = ReplaceLoadUser(manager, toposet[group[i]]);
294 (void)manager->Replace(toposet[group[i]], main);
295 }
296 return change;
297 }
298
GetFirstMonad(const FuncGraphPtr & fg)299 AnfNodePtr GetFirstMonad(const FuncGraphPtr &fg) {
300 auto ¶ms = fg->parameters();
301 auto end = (params.size() > 1) ? (params.rbegin() + 2) : params.rend();
302 auto iter = std::find_if(params.rbegin(), end, [](const AnfNodePtr ¶) { return HasAbstractUMonad(para); });
303 if (iter != end) {
304 return *iter;
305 }
306 auto monad = NewValueNode(kUMonad);
307 monad->set_abstract(kUMonad->ToAbstract());
308 return monad;
309 }
310
CheckExistSpecialNode(const AnfNodePtr & update_state,const AnfNodePtr & first_monad)311 bool CheckExistSpecialNode(const AnfNodePtr &update_state, const AnfNodePtr &first_monad) {
312 if (!update_state->isa<CNode>()) {
313 return false;
314 }
315 auto update_state_cnode = update_state->cast<CNodePtr>();
316 MS_EXCEPTION_IF_NULL(update_state_cnode);
317 constexpr size_t monad_input_index = 1;
318 constexpr size_t attach_input_index = 2;
319 auto monad = update_state_cnode->input(monad_input_index);
320 auto attach_node = update_state_cnode->input(attach_input_index);
321 MS_EXCEPTION_IF_NULL(attach_node);
322 if (attach_node->isa<CNode>() && IsSpecialNode(attach_node->cast<CNodePtr>())) {
323 return true;
324 }
325 if (monad == first_monad) {
326 return false;
327 }
328 return CheckExistSpecialNode(monad, first_monad);
329 }
330
331 // Replace UpdateStates with U for first load.
332 // Covert:
333 // u1 = UpdateState(u, c)
334 // p1 = Load(para1, u1) // first load for para1, and there are not special node before u1
335 // To:
336 // u1 = UpdateState(u, c)
337 // p1 = Load(para1, u') // u' is first monad in graph or new monad
ReplaceUpdateStateForLoad(const FuncGraphPtr & fg,const std::vector<AnfNodePtr> & need_replace_loads)338 bool ReplaceUpdateStateForLoad(const FuncGraphPtr &fg, const std::vector<AnfNodePtr> &need_replace_loads) {
339 if (need_replace_loads.size() == 0) {
340 return false;
341 }
342 bool change = false;
343 constexpr size_t second_input_index = 2;
344 auto monad = GetFirstMonad(fg);
345 for (const auto &load_node : need_replace_loads) {
346 if (!IsPrimitiveCNode(load_node, prim::kPrimLoad)) {
347 continue;
348 }
349 auto update_state = load_node->cast<CNodePtr>()->input(second_input_index);
350 auto mgr = fg->manager();
351 MS_EXCEPTION_IF_NULL(mgr);
352 // If the u1 only used by Load and one other updatestate, no need to replace u1 by u'.
353 auto &node_users = mgr->node_users()[update_state];
354 constexpr size_t kUserSize = 2;
355 if (!IsPrimitiveCNode(update_state, prim::kPrimUpdateState) || node_users.size() == kUserSize) {
356 continue;
357 }
358 // Check whether there is special node before the current load node in the execution sequence.
359 // If exist special node(the node may modify the load parameter), should not replace update_state for the load node.
360 if (CheckExistSpecialNode(update_state, monad)) {
361 continue;
362 }
363 mgr->SetEdge(load_node, second_input_index, monad);
364 change = true;
365 }
366 return change;
367 }
368 } // namespace
369
370 // Node1{primLoad,X,Y1},...,Node{Node's input != X},...,Node2{primLoad,X,Y2},... =>
371 // Node1{primLoad,X,Y1},...,Node{Nodes' input != X},...,Node1,...
ReplaceAutoMonadNode(const FuncGraphManagerPtr & manager) const372 bool AutoMonadEliminator::ReplaceAutoMonadNode(const FuncGraphManagerPtr &manager) const {
373 auto changed = false;
374 for (const FuncGraphPtr &fg : manager->func_graphs()) {
375 std::vector<AnfNodePtr> toposet = TopoSort(fg->get_return());
376 // Record the set of the first load of param which no nodes modify param before the load in toposort.
377 std::vector<AnfNodePtr> need_replace_loads;
378 // Record the param and the toposort id of the unload user of param, they may modify the value of param.
379 ParamUserMap param_users;
380 // Record the toposort id of special_op(call, partial, switch, switch_layer), they may modify the value of param.
381 std::vector<size_t> special_op_indexes;
382 auto load_groups = GenerateLoadGroups(fg, &toposet, &need_replace_loads, ¶m_users, &special_op_indexes);
383 // Split group if there is no-load node between two load nodes.
384 std::vector<std::vector<size_t>> need_merge_loads;
385 for (const auto &load_group : load_groups) {
386 auto &ref_key = load_group.first;
387 auto &group = load_group.second;
388 const auto ¶m_user_indexes = param_users[ref_key];
389 auto groups = SplitGroup(group, param_user_indexes, special_op_indexes);
390 (void)need_merge_loads.insert(need_merge_loads.cend(), groups.cbegin(), groups.cend());
391 }
392 for (auto &group : need_merge_loads) {
393 bool replaced = ReplaceSameGroupLoad(manager, toposet, group);
394 if (replaced) {
395 changed = true;
396 }
397 }
398 bool update_state_replaced = ReplaceUpdateStateForLoad(fg, need_replace_loads);
399 if (update_state_replaced) {
400 changed = true;
401 }
402 }
403 return changed;
404 }
405
406 // Eliminate auto monad node:
407 // From:
408 // u1 = UpdateState(...);
409 // xxx = User(u1); // Other users except below Depend.
410 // output = Depend(output, u1);
411 // return output;
412 // To:
413 // u1 = UpdateState(...);
414 // xxx = User(u1);
415 // return output;
EliminateAutoMonadNode(const FuncGraphManagerPtr & manager) const416 bool AutoMonadEliminator::EliminateAutoMonadNode(const FuncGraphManagerPtr &manager) const {
417 auto changed = false;
418 for (const FuncGraphPtr &fg : manager->func_graphs()) {
419 auto output = fg->output();
420 if (output == nullptr) {
421 continue;
422 }
423 if (!IsPrimitiveCNode(output, prim::kPrimDepend)) {
424 continue;
425 }
426 constexpr size_t attach_index = 2;
427 auto attach = output->cast<CNodePtr>()->input(attach_index);
428 if (!IsPrimitiveCNode(attach, prim::kPrimUpdateState)) {
429 continue;
430 }
431 auto &node_users = manager->node_users();
432 auto iter = node_users.find(attach);
433 if (iter == node_users.end()) {
434 MS_LOG(INTERNAL_EXCEPTION) << "No user of node: " << attach->DebugString();
435 }
436 auto &users = iter->second;
437 if (users.size() <= 1) {
438 continue;
439 }
440 constexpr size_t input_index = 1;
441 auto input = output->cast<CNodePtr>()->input(input_index);
442 MS_LOG(DEBUG) << "Change " << output->DebugString() << " -> " << input->DebugString();
443 fg->set_output(input);
444 changed = true;
445 }
446 MS_LOG(DEBUG) << "Changed: " << changed;
447 return changed;
448 }
449 } // namespace opt
450 } // namespace mindspore
451