• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * This is the C++ adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/).
3  *
4  * Copyright 2019-2022 Huawei Technologies Co., Ltd
5  *
6  * Licensed under the Apache License, Version 2.0 (the "License");
7  * you may not use this file except in compliance with the License.
8  * You may obtain a copy of the License at
9  *
10  * http://www.apache.org/licenses/LICENSE-2.0
11  *
12  * Unless required by applicable law or agreed to in writing, software
13  * distributed under the License is distributed on an "AS IS" BASIS,
14  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15  * See the License for the specific language governing permissions and
16  * limitations under the License.
17  */
18 
19 #include "include/common/utils/cse.h"
20 
21 #include <vector>
22 #include <set>
23 
24 #include "mindspore/core/ops/framework_ops.h"
25 #include "ir/anf.h"
26 #include "utils/hash_map.h"
27 #include "abstract/abstract_function.h"
28 #include "utils/flags.h"
29 #include "include/common/utils/utils.h"
30 #include "utils/anf_utils.h"
31 
32 namespace mindspore {
33 /* namespace to support opt */
34 namespace opt {
35 using mindspore::abstract::AbstractBase;
36 using mindspore::abstract::AbstractFunction;
37 using mindspore::abstract::AbstractFunctionPtr;
38 
WithRecomputedScope(const AnfNodePtr & node)39 bool WithRecomputedScope(const AnfNodePtr &node) {
40   MS_EXCEPTION_IF_NULL(node);
41   if (!node->isa<CNode>()) {
42     return false;
43   }
44   auto full_name_with_scope = node->fullname_with_scope();
45   return full_name_with_scope.find(kAttrRecompute) == 0;
46 }
47 
IsSetRecomputed(const CNodePtr & a,const CNodePtr & b)48 bool IsSetRecomputed(const CNodePtr &a, const CNodePtr &b) {
49   return (WithRecomputedScope(a) && !a->HasAttr(kAttrNeedCseAfterRecompute)) ||
50          (WithRecomputedScope(b) && !b->HasAttr(kAttrNeedCseAfterRecompute));
51 }
52 
IsHiddenSideEffectNode(const AnfNodePtr & node)53 bool IsHiddenSideEffectNode(const AnfNodePtr &node) {
54   auto prim = GetCNodePrimitive(node);
55   if (prim == nullptr) {
56     return false;
57   }
58   return prim->HasAttr(GRAPH_FLAG_SIDE_EFFECT_HIDDEN);
59 }
60 
UpdateDebugInfoAndDumpFlag(const AnfNodePtr & main,const AnfNodePtr & node)61 void UpdateDebugInfoAndDumpFlag(const AnfNodePtr &main, const AnfNodePtr &node) {
62   if (main == nullptr || !main->isa<CNode>()) {
63     return;
64   }
65   if (AnfUtils::GetDumpFlag(node) && !AnfUtils::GetDumpFlag(main)) {
66     AnfUtils::SetDumpFlag(main);
67   }
68   auto main_cnode = main->cast<CNodePtr>();
69   main_cnode->AddFusedDebugInfo(node);
70 }
71 
AbsOf(const AnfNodePtr & node,bool ignore_fg_abs_tracking_id)72 BasePtr AbsOf(const AnfNodePtr &node, bool ignore_fg_abs_tracking_id) {
73   MS_EXCEPTION_IF_NULL(node);
74   auto node_abs = node->abstract();
75   // In testcase: TestOptOpt.CSE, node->abstract() is null.
76   if (node_abs == nullptr) {
77     return kValueAny;
78   }
79   if (node_abs->isa<abstract::PrimitiveAbstractClosure>()) {
80     // Ignore the tracking_id and prim pointer hash.
81     auto prim_abs = node_abs->cast_ptr<abstract::PrimitiveAbstractClosure>();
82     return prim_abs->prim();
83   } else if (ignore_fg_abs_tracking_id && node_abs->isa<abstract::FuncGraphAbstractClosure>()) {
84     // Ignore the tracking_id.
85     return node_abs->cast_ptr<abstract::AbstractFunction>()->CopyWithoutTrackingId();
86   }
87   return node_abs;
88 }
89 
BuildOrderGroupForOneGraph(const FuncGraphPtr & fg)90 bool CSE::BuildOrderGroupForOneGraph(const FuncGraphPtr &fg) {
91   MS_EXCEPTION_IF_NULL(fg);
92   std::vector<std::size_t> order_group;
93   mindspore::HashMap<std::size_t, std::vector<AnfNodePtr>> groups;
94   mindspore::HashMap<AnfNodePtr, std::size_t> hashes;
95 
96   std::vector<AnfNodePtr> toposet = TopoSort(fg->get_return());
97   for (const auto &node : toposet) {
98     MS_EXCEPTION_IF_NULL(node);
99     if (hashes.find(node) != hashes.end()) {
100       continue;
101     }
102     if (IsHiddenSideEffectNode(node) && node->func_graph() != nullptr) {
103       MS_LOG(DEBUG) << "Add hidden func graph:" << node->func_graph()->ToString();
104       (void)hidden_side_effect_func_graphs_.insert(node->func_graph());
105     }
106     std::size_t h = 0;
107     if (node->isa<ValueNode>()) {
108       auto prim = GetValueNode<PrimitivePtr>(node);
109       if (IsPrimitiveEquals(prim, prim::kPrimUpdateState)) {
110         continue;
111       }
112       ValueNodePtr value_node = node->cast<ValueNodePtr>();
113       auto value = value_node->value();
114       MS_EXCEPTION_IF_NULL(value);
115       h = hash_combine(value->hash(), (AbsOf(value_node, true)->hash()));
116     } else if (node->isa<CNode>()) {
117       auto cnode = node->cast<CNodePtr>();
118       auto &inputs = cnode->inputs();
119       size_t init = 0;
120       h = std::accumulate(inputs.begin(), inputs.end(), init, [&hashes](std::size_t hash, const AnfNodePtr &node_in) {
121         return hash_combine(hash, hashes[node_in]);
122       });
123     } else if (node->isa<Parameter>()) {
124       h = node->hash();
125     } else {
126       MS_LOG(ERROR) << "Unknown node type";
127     }
128 
129     hashes[node] = h;
130     if (groups.find(h) == groups.end()) {
131       std::vector<AnfNodePtr> innervec({node});
132       groups[h] = innervec;
133       order_group.emplace_back(h);
134     } else {
135       groups[h].push_back(node);
136     }
137   }
138   return CalReplaceNodes(order_group, &groups);
139 }
140 
DoReplace(const FuncGraphManagerPtr & manager)141 void CSE::DoReplace(const FuncGraphManagerPtr &manager) {
142   // if A is a hidden_side_effect node, then A's user B can't be replaced by main, then B's user C can't be replaced by
143   // main.
144   auto tr = manager->Transact();
145   HashSet<AnfNodePtr> cannot_replace_nodes;
146   for (const auto &[node, main] : replicated_nodes_) {
147     bool main_input_cannot_replace = false;
148     if (main->isa<CNode>()) {
149       auto c_main = main->cast<CNodePtr>();
150       const auto &c_main_inputs = c_main->inputs();
151       auto input_can_not_replace = [&cannot_replace_nodes](const AnfNodePtr &node) {
152         return cannot_replace_nodes.find(node) != cannot_replace_nodes.cend();
153       };
154       main_input_cannot_replace = std::any_of(c_main_inputs.cbegin(), c_main_inputs.cend(), input_can_not_replace);
155     }
156     if (HasHiddenSideEffect(main) || main_input_cannot_replace) {
157       (void)cannot_replace_nodes.insert(main);
158       continue;
159     }
160     // We don't merge primitive cnodes with random effect.
161     MS_LOG(DEBUG) << "CSE replace, node:" << node->DebugString() << ", main:" << main->DebugString();
162     tr.Replace(node, main);
163   }
164   tr.Commit();
165 }
166 
BuildOrderGroupAndDoReplace(const FuncGraphManagerPtr manager)167 bool CSE::BuildOrderGroupAndDoReplace(const FuncGraphManagerPtr manager) {
168   bool changed = false;
169   for (const auto &fg : manager->func_graphs()) {
170     changed = BuildOrderGroupForOneGraph(fg) || changed;
171   }
172   DoReplace(manager);
173   return changed;
174 }
175 
176 // Check whether is a func graph call node and func graph has hidden side effect node.
IsHiddenSideEffectCall(const AnfNodePtr & node)177 bool CSE::IsHiddenSideEffectCall(const AnfNodePtr &node) {
178   if (!node->isa<CNode>()) {
179     return false;
180   }
181   auto cnode = node->cast<CNodePtr>();
182   // Check weather it is a func graph call.
183   if (IsValueNode<Primitive>(cnode->input(kAnfPrimitiveIndex))) {
184     return false;
185   }
186   // If it is a func graph call node, get all graphs  from abstract.
187   auto func_graphs = abstract::GetFuncGraphsFromCallNode(cnode);
188   auto is_hidden_side_effect_graph = [this](const FuncGraphPtr &fg) -> bool {
189     return hidden_side_effect_func_graphs_.find(fg) != hidden_side_effect_func_graphs_.end();
190   };
191   return std::any_of(func_graphs.cbegin(), func_graphs.cend(), is_hidden_side_effect_graph);
192 }
193 
HasHiddenSideEffect(const AnfNodePtr & node)194 bool CSE::HasHiddenSideEffect(const AnfNodePtr &node) {
195   if (IsHiddenSideEffectNode(node)) {
196     return true;
197   }
198   if (IsHiddenSideEffectCall(node)) {
199     return true;
200   }
201   return false;
202 }
203 
// Looks `node` up in replicated_nodes_ and returns the node recorded as its replacement, or
// `node` itself when there is no entry. Throws through MS_EXCEPTION_IF_NULL when `node` is null.
AnfNodePtr CSE::GetReplicatedNode(const AnfNodePtr &node) const {
  MS_EXCEPTION_IF_NULL(node);
  const auto found = replicated_nodes_.find(node);
  if (found == replicated_nodes_.cend()) {
    return node;
  }
  return found->second;
}
212 
AddReplicatedNode(const AnfNodePtr & node,const AnfNodePtr & main)213 void CSE::AddReplicatedNode(const AnfNodePtr &node, const AnfNodePtr &main) {
214   MS_EXCEPTION_IF_NULL(node);
215   MS_EXCEPTION_IF_NULL(main);
216   if (node == main) {
217     MS_LOG(WARNING) << "Can't replace node by itself, node:" << node->DebugString();
218     return;
219   }
220   (void)replicated_nodes_.emplace(node, main);
221 }
222 
CheckReplace(const AnfNodePtr & main,const AnfNodePtr & node)223 bool CSE::CheckReplace(const AnfNodePtr &main, const AnfNodePtr &node) {
224   MS_EXCEPTION_IF_NULL(main);
225   MS_EXCEPTION_IF_NULL(node);
226   if (main->isa<ValueNode>() && node->isa<ValueNode>()) {
227     auto main_value = GetValueNode(main);
228     auto node_value = GetValueNode(node);
229     return (AbsOf(main, true) == AbsOf(node, true)) && (*main_value == *node_value);
230   } else if (main->isa<CNode>() && node->isa<CNode>()) {
231     auto c_main = main->cast<CNodePtr>();
232     auto c_node = node->cast<CNodePtr>();
233     // Not do cse for the node set recompute before the recompute pass.
234     if (IsSetRecomputed(c_main, c_node)) {
235       return false;
236     }
237     // Can not merge J because the J user size should be 1.
238     if (IsPrimitiveCNode(c_main, prim::kPrimJ) || IsPrimitiveCNode(c_main, prim::kPrimReceive)) {
239       return false;
240     }
241     if (IsPrimitiveCNode(node, prim::kPrimPyExecute)) {
242       return false;
243     }
244     const auto &inputs1 = c_main->inputs();
245     const auto &inputs2 = c_node->inputs();
246     if (inputs1.size() != inputs2.size()) {
247       return false;
248     }
249     // Check inputs, all inputs should equal.
250     for (size_t i = 0; i < inputs1.size(); i++) {
251       auto input1 = GetReplicatedNode(inputs1[i]);
252       auto input2 = GetReplicatedNode(inputs2[i]);
253       MS_EXCEPTION_IF_NULL(input1);
254       MS_EXCEPTION_IF_NULL(input2);
255       if ((input1 == input2) || (*input1 == *input2)) {
256         continue;
257       }
258       // Handle the case of two different Tensor, but with the same value.
259       if (IsValueNode<tensor::Tensor>(input1) && IsValueNode<tensor::Tensor>(input2)) {
260         auto tensor1 = GetValueNode<tensor::TensorPtr>(input1);
261         auto tensor2 = GetValueNode<tensor::TensorPtr>(input2);
262         if (tensor1->ValueEqual(*tensor2)) {
263           continue;
264         }
265       }
266       return false;
267     }
268     return true;
269   }
270   // a parameter node.
271   return false;
272 }
273 
CalReplaceNodes(const std::vector<std::size_t> & order_group,mindspore::HashMap<std::size_t,std::vector<AnfNodePtr>> * groups)274 bool CSE::CalReplaceNodes(const std::vector<std::size_t> &order_group,
275                           mindspore::HashMap<std::size_t, std::vector<AnfNodePtr>> *groups) {
276   bool changes = false;
277   std::set<size_t> clear_set;
278   for (auto &h : order_group) {
279     std::vector<AnfNodePtr> &group = (*groups)[h];
280     // If there are more than 2 node in that group, they may be same common expression can be eliminated.
281     if (group.size() > 1) {
282       for (size_t k = 0; k < group.size() - 1; k++) {
283         AnfNodePtr main = group[k];
284         MS_EXCEPTION_IF_NULL(main);
285 
286         // When all node in group has been replaced
287         // or a valuenode node, skip compare in group
288         if ((k + 1 + clear_set.size() == group.size()) || (k > 0 && main->isa<ValueNode>())) {
289           break;
290         }
291 
292         // skip node has been replaced
293         if (clear_set.find(k) != clear_set.end()) {
294           continue;
295         }
296 
297         // Compare with rest elements in this group.
298         for (size_t i = k + 1; i < group.size(); i++) {
299           auto node = group[i];
300           MS_EXCEPTION_IF_NULL(node);
301 
302           if (clear_set.find(i) != clear_set.end()) {
303             continue;
304           }
305           if (main->func_graph() != node->func_graph()) {
306             continue;
307           }
308           if (CheckReplace(node, main)) {
309             changes = true;
310             UpdateDebugInfoAndDumpFlag(main, node);
311             MS_LOG(DEBUG) << "Add replicated_nodes_, node:" << node->DebugString() << ", main:" << main->DebugString();
312             AddReplicatedNode(node, main);
313             (void)clear_set.insert(i);
314           }
315         }
316       }
317       clear_set.clear();
318     }
319   }
320   return changes;
321 }
322 
Init()323 void CSE::Init() {
324   hidden_side_effect_func_graphs_.clear();
325   replicated_nodes_.clear();
326 }
327 
Cse(const FuncGraphPtr root,const FuncGraphManagerPtr manager)328 bool CSE::Cse(const FuncGraphPtr root, const FuncGraphManagerPtr manager) {
329   MS_EXCEPTION_IF_NULL(manager);
330   Init();
331   manager->AddFuncGraph(root);
332   return BuildOrderGroupAndDoReplace(manager);
333 }
334 }  // namespace opt
335 }  // namespace mindspore
336