• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2021-2022 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #include "frontend/optimizer/auto_monad_eliminate.h"
18 
19 #include <algorithm>
20 #include <memory>
21 #include <string>
22 #include <optional>
23 #include <map>
24 #include <utility>
25 #include <vector>
26 
27 #include "mindspore/core/ops/sequence_ops.h"
28 #include "mindspore/core/ops/framework_ops.h"
29 #include "utils/hash_map.h"
30 #include "utils/ordered_map.h"
31 #include "abstract/abstract_value.h"
32 
33 namespace mindspore {
34 namespace opt {
35 namespace {
36 using ParamUserMap = mindspore::HashMap<std::string, std::vector<size_t>>;
37 using LoadGraphMap = OrderedMap<std::string, std::vector<size_t>>;
38 
GetRefKey(const AnfNodePtr & node)39 std::optional<std::string> GetRefKey(const AnfNodePtr &node) {
40   auto abs = node->abstract();
41   if (abs == nullptr) {
42     // Abstract for some Depends node are not proper set, we follow its input.
43     if (IsPrimitiveCNode(node, prim::kPrimDepend)) {
44       return GetRefKey(node->cast<CNodePtr>()->input(1));
45     }
46     // Abstract should be set except UpdateState nodes.
47     if (!IsPrimitiveCNode(node, prim::kPrimUpdateState)) {
48       MS_LOG(WARNING) << "Abstract not set for " << node->DebugString();
49     }
50     return std::nullopt;
51   }
52   auto abs_ref = abs->cast<abstract::AbstractRefPtr>();
53   if (abs_ref == nullptr) {
54     return std::nullopt;
55   }
56   auto ref_key = abs_ref->ref_key_value()->cast<StringImmPtr>();
57   if (ref_key == nullptr) {
58     return std::nullopt;
59   }
60   return ref_key->value();
61 }
62 
HasSideEffect(const CNodePtr & cnode)63 bool HasSideEffect(const CNodePtr &cnode) {
64   const auto &inputs = cnode->inputs();
65   constexpr size_t kRequiredArgs = 2;
66   if (inputs.size() > kRequiredArgs) {
67     return HasAbstractMonad(inputs.back());
68   }
69   return false;
70 }
71 
IsSpecialNode(const CNodePtr & cnode)72 bool IsSpecialNode(const CNodePtr &cnode) {
73   const auto &first_input = cnode->input(0);
74   return IsPrimitiveCNode(first_input, prim::kPrimJ) || IsPrimitiveCNode(first_input, prim::kPrimVmap) ||
75          IsPrimitiveCNode(first_input, prim::kPrimTaylor) || IsPrimitiveCNode(first_input, prim::kPrimShard) ||
76          IsValueNode<FuncGraph>(first_input) || cnode->IsApply(prim::kPrimCall) || cnode->IsApply(prim::kPrimPartial) ||
77          cnode->IsApply(prim::kPrimSwitch) || cnode->IsApply(prim::kPrimSwitchLayer);
78 }
79 
GenerateLoadGroups(const FuncGraphPtr & fg,std::vector<AnfNodePtr> * toposet,std::vector<AnfNodePtr> * need_replace_loads,ParamUserMap * param_users,std::vector<size_t> * special_op_indexes)80 LoadGraphMap GenerateLoadGroups(const FuncGraphPtr &fg, std::vector<AnfNodePtr> *toposet,
81                                 std::vector<AnfNodePtr> *need_replace_loads, ParamUserMap *param_users,
82                                 std::vector<size_t> *special_op_indexes) {
83   LoadGraphMap load_groups;
84   // Record inputs of load and id of load in toposort.
85   // RefKey --> (Monad --> index).
86   std::map<std::string, std::map<AnfNodePtr, size_t>> param_monads;
87   auto mgr = fg->manager();
88   MS_EXCEPTION_IF_NULL(mgr);
89   for (size_t i = 0; i < toposet->size(); i++) {
90     auto cnode = dyn_cast<CNode>((*toposet)[i]);
91     // Exclude free variable node.
92     if (cnode == nullptr || cnode->func_graph() != fg) {
93       continue;
94     }
95     // Handle Load node.
96     if (cnode->IsApply(prim::kPrimLoad)) {
97       auto ref_key = GetRefKey(cnode->input(1));
98       if (!ref_key.has_value()) {
99         MS_LOG(INFO) << "Load without ref key: " << cnode->DebugString();
100         continue;
101       }
102       // Group load nodes by their input ref key.
103       auto &group = load_groups[ref_key.value()];
104       constexpr size_t monad_index = 2;
105       auto monad = cnode->input(monad_index);
106       std::map<AnfNodePtr, size_t> &cur_param_monads = param_monads[ref_key.value()];
107       const auto &iter = cur_param_monads.find(monad);
108       // Remove duplicate load which has the same inputs, otherwise there may be an error in the load grouping.
109       if (iter != cur_param_monads.end()) {
110         auto id = iter->second;
111         auto &first_load = (*toposet)[id];
112         (void)mgr->Replace(cnode, first_load);
113         (*toposet)[i] = first_load;
114         continue;
115       } else {
116         cur_param_monads[monad] = i;
117         (void)group.emplace_back(i);
118       }
119       if (group.size() == 1) {
120         // The first load user of param in toposort, if it can be replace load(param, ud) with load(param, u),
121         // Means there are not nodes which modify param before the load.
122         const bool param_not_used = (param_users->find(ref_key.value()) == param_users->end());
123         const bool can_replace = (param_not_used && special_op_indexes->empty());
124         if (can_replace) {
125           (void)need_replace_loads->emplace_back(cnode);
126         }
127       }
128       continue;
129     }
130     // Record special cnode.
131     if (IsSpecialNode(cnode)) {
132       (void)special_op_indexes->emplace_back(i);
133       continue;
134     }
135     // Record param user in toposort nodes.
136     // We only check side effect cnodes or Depend nodes.
137     if (HasSideEffect(cnode) || cnode->IsApply(prim::kPrimDepend)) {
138       for (size_t n = 1; n < cnode->size(); ++n) {
139         const auto &input = cnode->input(n);
140         auto ref_key = GetRefKey(input);
141         if (ref_key.has_value()) {
142           (void)(*param_users)[ref_key.value()].emplace_back(i);
143         }
144       }
145     }
146   }
147   return load_groups;
148 }
149 
// Return true if some element of `indexes` lies strictly inside the open interval (first, second).
bool HasIndexBetween(const std::vector<size_t> &indexes, size_t first, size_t second) {
  const auto strictly_inside = [first, second](size_t index) { return first < index && index < second; };
  return std::find_if(indexes.cbegin(), indexes.cend(), strictly_inside) != indexes.cend();
}
154 
SplitGroup(const std::vector<size_t> & group,const std::vector<size_t> & param_user_indexes,const std::vector<size_t> & special_op_indexes)155 std::vector<std::vector<size_t>> SplitGroup(const std::vector<size_t> &group,
156                                             const std::vector<size_t> &param_user_indexes,
157                                             const std::vector<size_t> &special_op_indexes) {
158   if (group.size() <= 1) {
159     return {};
160   }
161   size_t cur_load_index = 1;
162   size_t pre_load_index = 0;
163   std::vector<size_t> cur_group = {group[pre_load_index]};
164   std::vector<std::vector<size_t>> split_groups;
165   while (cur_load_index < group.size()) {
166     const auto cur_load = group[cur_load_index];
167     const auto prev_load = group[pre_load_index];
168     // Exist node which is the user of load_param between prev_load and cur_load,
169     // Do not divide into the same group.
170     if (HasIndexBetween(param_user_indexes, prev_load, cur_load) ||
171         HasIndexBetween(special_op_indexes, prev_load, cur_load)) {
172       (void)split_groups.emplace_back(std::move(cur_group));
173     }
174     cur_group.push_back(cur_load);
175     pre_load_index++;
176     cur_load_index++;
177   }
178   // push back the last splited group.
179   split_groups.push_back(cur_group);
180   return split_groups;
181 }
182 
183 // Pattern1======================================
184 // a = Load(para1, u1)
185 // ...
186 // b = Load(para1, u2)
187 // u3 = UpdateState(u2, b)
188 // ==>
189 // delete the UpdateState
DeleteLoadUserUpdateState(const FuncGraphManagerPtr & manager,const AnfNodePtr & load_user)190 void DeleteLoadUserUpdateState(const FuncGraphManagerPtr &manager, const AnfNodePtr &load_user) {
191   const auto &update_state_cnode = load_user->cast<CNodePtr>();
192   constexpr size_t monad_index = 1;
193   const auto &monad = update_state_cnode->input(monad_index);
194   (void)manager->Replace(load_user, monad);
195 }
196 
197 // Pattern2======================================
198 // a = Load(para1, u1)
199 // ...
200 // b = Load(para1, u2)
201 // t = make_tuple(x, b)
202 // u3 = UpdateState(u2, t)
203 // ==>
204 // a = Load(para1, u1)
205 // ...
206 // b = Load(para1, u2)
207 // u3 = UpdateState(u2, x)
DeleteLoadUserMakeTuple(const FuncGraphManagerPtr & manager,const CNodePtr & make_tuple,const AnfNodePtr & load)208 void DeleteLoadUserMakeTuple(const FuncGraphManagerPtr &manager, const CNodePtr &make_tuple, const AnfNodePtr &load) {
209   // Initialize the other_input with load in case of all the inputs of the make_tuple is the same load.
210   AnfNodePtr other_input = load;
211   for (size_t i = 1; i < make_tuple->size(); i++) {
212     if (make_tuple->input(i) != load) {
213       other_input = make_tuple->input(i);
214       break;
215     }
216   }
217   MS_EXCEPTION_IF_NULL(other_input);
218   (void)manager->Replace(make_tuple, other_input);
219 }
220 
221 // Pattern3======================================
222 // a = Load(para1, u1)
223 // ...
224 // b = Load(para1, u2)
225 // t = make_tuple(x, y, b, z)
226 // u3 = UpdateState(u2, t)
227 // ==>
228 // a = Load(para1, u1)
229 // ...
230 // b = Load(para1, u2)
231 // t = make_tuple(x, y, z)
232 // u3 = UpdateState(u2, t)
ReplaceLoadUserMakeTuple(const FuncGraphManagerPtr & manager,const CNodePtr & make_tuple,const AnfNodePtr & load)233 void ReplaceLoadUserMakeTuple(const FuncGraphManagerPtr &manager, const CNodePtr &make_tuple, const AnfNodePtr &load) {
234   auto &make_tuple_inputs = make_tuple->inputs();
235   std::vector<AnfNodePtr> new_make_tuple_inputs;
236   (void)std::copy_if(make_tuple_inputs.begin(), make_tuple_inputs.end(), std::back_inserter(new_make_tuple_inputs),
237                      [load](const AnfNodePtr &input) { return load != input; });
238   auto fg = make_tuple->func_graph();
239   MS_EXCEPTION_IF_NULL(fg);
240   const auto &new_make_tuple = fg->NewCNode(new_make_tuple_inputs);
241   // Set abstract for the MakeTuple node.
242   abstract::AbstractBasePtrList element_abstracts;
243   (void)std::transform(new_make_tuple_inputs.begin() + 1, new_make_tuple_inputs.end(),
244                        std::back_inserter(element_abstracts),
245                        [](const AnfNodePtr &input) { return input->abstract(); });
246   new_make_tuple->set_abstract(std::make_shared<abstract::AbstractTuple>(element_abstracts));
247   (void)manager->Replace(make_tuple, new_make_tuple);
248 }
249 
ReplaceLoadUser(const FuncGraphManagerPtr & manager,const AnfNodePtr & load)250 bool ReplaceLoadUser(const FuncGraphManagerPtr &manager, const AnfNodePtr &load) {
251   bool change = false;
252   auto load_users = manager->node_users()[load];
253   for (const auto &load_user : load_users) {
254     // Pattern1
255     if (IsPrimitiveCNode(load_user.first, prim::kPrimUpdateState)) {
256       DeleteLoadUserUpdateState(manager, load_user.first);
257       change = true;
258       continue;
259     }
260 
261     if (IsPrimitiveCNode(load_user.first, prim::kPrimMakeTuple)) {
262       const auto &make_tuple = load_user.first->cast<CNodePtr>();
263       auto &maketuple_users = manager->node_users()[make_tuple];
264       auto maketuple_as_input_of_update =
265         maketuple_users.size() == 1 && IsPrimitiveCNode(maketuple_users.back().first, prim::kPrimUpdateState);
266       if (!maketuple_as_input_of_update) {
267         continue;
268       }
269       // Pattern2
270       if (make_tuple->size() == 3) {
271         DeleteLoadUserMakeTuple(manager, make_tuple, load);
272         change = true;
273         continue;
274       }
275       // Pattern3
276       if (make_tuple->size() > 3) {
277         ReplaceLoadUserMakeTuple(manager, make_tuple, load);
278         change = true;
279       }
280     }
281   }
282   return change;
283 }
284 
ReplaceSameGroupLoad(const FuncGraphManagerPtr & manager,const std::vector<AnfNodePtr> & toposet,const std::vector<size_t> & group)285 bool ReplaceSameGroupLoad(const FuncGraphManagerPtr &manager, const std::vector<AnfNodePtr> &toposet,
286                           const std::vector<size_t> &group) {
287   if (group.size() <= 1) {
288     return false;
289   }
290   bool change = false;
291   const auto &main = toposet[group[0]];
292   for (size_t i = 1; i < group.size(); i++) {
293     change = ReplaceLoadUser(manager, toposet[group[i]]);
294     (void)manager->Replace(toposet[group[i]], main);
295   }
296   return change;
297 }
298 
GetFirstMonad(const FuncGraphPtr & fg)299 AnfNodePtr GetFirstMonad(const FuncGraphPtr &fg) {
300   auto &params = fg->parameters();
301   auto end = (params.size() > 1) ? (params.rbegin() + 2) : params.rend();
302   auto iter = std::find_if(params.rbegin(), end, [](const AnfNodePtr &para) { return HasAbstractUMonad(para); });
303   if (iter != end) {
304     return *iter;
305   }
306   auto monad = NewValueNode(kUMonad);
307   monad->set_abstract(kUMonad->ToAbstract());
308   return monad;
309 }
310 
CheckExistSpecialNode(const AnfNodePtr & update_state,const AnfNodePtr & first_monad)311 bool CheckExistSpecialNode(const AnfNodePtr &update_state, const AnfNodePtr &first_monad) {
312   if (!update_state->isa<CNode>()) {
313     return false;
314   }
315   auto update_state_cnode = update_state->cast<CNodePtr>();
316   MS_EXCEPTION_IF_NULL(update_state_cnode);
317   constexpr size_t monad_input_index = 1;
318   constexpr size_t attach_input_index = 2;
319   auto monad = update_state_cnode->input(monad_input_index);
320   auto attach_node = update_state_cnode->input(attach_input_index);
321   MS_EXCEPTION_IF_NULL(attach_node);
322   if (attach_node->isa<CNode>() && IsSpecialNode(attach_node->cast<CNodePtr>())) {
323     return true;
324   }
325   if (monad == first_monad) {
326     return false;
327   }
328   return CheckExistSpecialNode(monad, first_monad);
329 }
330 
331 // Replace UpdateStates with U for first load.
332 // Covert:
333 // u1 = UpdateState(u, c)
334 // p1 = Load(para1, u1)  // first load for para1, and there are not special node before u1
335 // To:
336 // u1 = UpdateState(u, c)
337 // p1 = Load(para1, u')  // u' is first monad in graph or new monad
ReplaceUpdateStateForLoad(const FuncGraphPtr & fg,const std::vector<AnfNodePtr> & need_replace_loads)338 bool ReplaceUpdateStateForLoad(const FuncGraphPtr &fg, const std::vector<AnfNodePtr> &need_replace_loads) {
339   if (need_replace_loads.size() == 0) {
340     return false;
341   }
342   bool change = false;
343   constexpr size_t second_input_index = 2;
344   auto monad = GetFirstMonad(fg);
345   for (const auto &load_node : need_replace_loads) {
346     if (!IsPrimitiveCNode(load_node, prim::kPrimLoad)) {
347       continue;
348     }
349     auto update_state = load_node->cast<CNodePtr>()->input(second_input_index);
350     auto mgr = fg->manager();
351     MS_EXCEPTION_IF_NULL(mgr);
352     // If the u1 only used by Load and one other updatestate, no need to replace u1 by u'.
353     auto &node_users = mgr->node_users()[update_state];
354     constexpr size_t kUserSize = 2;
355     if (!IsPrimitiveCNode(update_state, prim::kPrimUpdateState) || node_users.size() == kUserSize) {
356       continue;
357     }
358     // Check whether there is special node before the current load node in the execution sequence.
359     // If exist special node(the node may modify the load parameter), should not replace update_state for the load node.
360     if (CheckExistSpecialNode(update_state, monad)) {
361       continue;
362     }
363     mgr->SetEdge(load_node, second_input_index, monad);
364     change = true;
365   }
366   return change;
367 }
368 }  // namespace
369 
370 // Node1{primLoad,X,Y1},...,Node{Node's input != X},...,Node2{primLoad,X,Y2},... =>
371 // Node1{primLoad,X,Y1},...,Node{Nodes' input != X},...,Node1,...
ReplaceAutoMonadNode(const FuncGraphManagerPtr & manager) const372 bool AutoMonadEliminator::ReplaceAutoMonadNode(const FuncGraphManagerPtr &manager) const {
373   auto changed = false;
374   for (const FuncGraphPtr &fg : manager->func_graphs()) {
375     std::vector<AnfNodePtr> toposet = TopoSort(fg->get_return());
376     // Record the set of the first load of param which no nodes modify param before the load in toposort.
377     std::vector<AnfNodePtr> need_replace_loads;
378     // Record the param and the toposort id of the unload user of param, they may modify the value of param.
379     ParamUserMap param_users;
380     // Record the toposort id of special_op(call, partial, switch, switch_layer), they may modify the value of param.
381     std::vector<size_t> special_op_indexes;
382     auto load_groups = GenerateLoadGroups(fg, &toposet, &need_replace_loads, &param_users, &special_op_indexes);
383     // Split group if there is no-load node between two load nodes.
384     std::vector<std::vector<size_t>> need_merge_loads;
385     for (const auto &load_group : load_groups) {
386       auto &ref_key = load_group.first;
387       auto &group = load_group.second;
388       const auto &param_user_indexes = param_users[ref_key];
389       auto groups = SplitGroup(group, param_user_indexes, special_op_indexes);
390       (void)need_merge_loads.insert(need_merge_loads.cend(), groups.cbegin(), groups.cend());
391     }
392     for (auto &group : need_merge_loads) {
393       bool replaced = ReplaceSameGroupLoad(manager, toposet, group);
394       if (replaced) {
395         changed = true;
396       }
397     }
398     bool update_state_replaced = ReplaceUpdateStateForLoad(fg, need_replace_loads);
399     if (update_state_replaced) {
400       changed = true;
401     }
402   }
403   return changed;
404 }
405 
406 // Eliminate auto monad node:
407 // From:
408 //    u1 = UpdateState(...);
409 //    xxx = User(u1); // Other users except below Depend.
410 //    output = Depend(output, u1);
411 //    return output;
412 // To:
413 //    u1 = UpdateState(...);
414 //    xxx = User(u1);
415 //    return output;
EliminateAutoMonadNode(const FuncGraphManagerPtr & manager) const416 bool AutoMonadEliminator::EliminateAutoMonadNode(const FuncGraphManagerPtr &manager) const {
417   auto changed = false;
418   for (const FuncGraphPtr &fg : manager->func_graphs()) {
419     auto output = fg->output();
420     if (output == nullptr) {
421       continue;
422     }
423     if (!IsPrimitiveCNode(output, prim::kPrimDepend)) {
424       continue;
425     }
426     constexpr size_t attach_index = 2;
427     auto attach = output->cast<CNodePtr>()->input(attach_index);
428     if (!IsPrimitiveCNode(attach, prim::kPrimUpdateState)) {
429       continue;
430     }
431     auto &node_users = manager->node_users();
432     auto iter = node_users.find(attach);
433     if (iter == node_users.end()) {
434       MS_LOG(INTERNAL_EXCEPTION) << "No user of node: " << attach->DebugString();
435     }
436     auto &users = iter->second;
437     if (users.size() <= 1) {
438       continue;
439     }
440     constexpr size_t input_index = 1;
441     auto input = output->cast<CNodePtr>()->input(input_index);
442     MS_LOG(DEBUG) << "Change " << output->DebugString() << " -> " << input->DebugString();
443     fg->set_output(input);
444     changed = true;
445   }
446   MS_LOG(DEBUG) << "Changed: " << changed;
447   return changed;
448 }
449 }  // namespace opt
450 }  // namespace mindspore
451