1 /**
2 * This is the C++ adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/).
3 *
4 * Copyright 2019-2022 Huawei Technologies Co., Ltd
5 *
6 * Licensed under the Apache License, Version 2.0 (the "License");
7 * you may not use this file except in compliance with the License.
8 * You may obtain a copy of the License at
9 *
10 * http://www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing, software
13 * distributed under the License is distributed on an "AS IS" BASIS,
14 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15 * See the License for the specific language governing permissions and
16 * limitations under the License.
17 */
18
#include "include/common/utils/cse.h"

#include <algorithm>
#include <iterator>
#include <numeric>
#include <set>
#include <vector>

#include "mindspore/core/ops/framework_ops.h"
#include "ir/anf.h"
#include "utils/hash_map.h"
#include "abstract/abstract_function.h"
#include "utils/flags.h"
#include "include/common/utils/utils.h"
#include "utils/anf_utils.h"
31
32 namespace mindspore {
33 /* namespace to support opt */
34 namespace opt {
35 using mindspore::abstract::AbstractBase;
36 using mindspore::abstract::AbstractFunction;
37 using mindspore::abstract::AbstractFunctionPtr;
38
WithRecomputedScope(const AnfNodePtr & node)39 bool WithRecomputedScope(const AnfNodePtr &node) {
40 MS_EXCEPTION_IF_NULL(node);
41 if (!node->isa<CNode>()) {
42 return false;
43 }
44 auto full_name_with_scope = node->fullname_with_scope();
45 return full_name_with_scope.find(kAttrRecompute) == 0;
46 }
47
IsSetRecomputed(const CNodePtr & a,const CNodePtr & b)48 bool IsSetRecomputed(const CNodePtr &a, const CNodePtr &b) {
49 return (WithRecomputedScope(a) && !a->HasAttr(kAttrNeedCseAfterRecompute)) ||
50 (WithRecomputedScope(b) && !b->HasAttr(kAttrNeedCseAfterRecompute));
51 }
52
IsHiddenSideEffectNode(const AnfNodePtr & node)53 bool IsHiddenSideEffectNode(const AnfNodePtr &node) {
54 auto prim = GetCNodePrimitive(node);
55 if (prim == nullptr) {
56 return false;
57 }
58 return prim->HasAttr(GRAPH_FLAG_SIDE_EFFECT_HIDDEN);
59 }
60
UpdateDebugInfoAndDumpFlag(const AnfNodePtr & main,const AnfNodePtr & node)61 void UpdateDebugInfoAndDumpFlag(const AnfNodePtr &main, const AnfNodePtr &node) {
62 if (main == nullptr || !main->isa<CNode>()) {
63 return;
64 }
65 if (AnfUtils::GetDumpFlag(node) && !AnfUtils::GetDumpFlag(main)) {
66 AnfUtils::SetDumpFlag(main);
67 }
68 auto main_cnode = main->cast<CNodePtr>();
69 main_cnode->AddFusedDebugInfo(node);
70 }
71
AbsOf(const AnfNodePtr & node,bool ignore_fg_abs_tracking_id)72 BasePtr AbsOf(const AnfNodePtr &node, bool ignore_fg_abs_tracking_id) {
73 MS_EXCEPTION_IF_NULL(node);
74 auto node_abs = node->abstract();
75 // In testcase: TestOptOpt.CSE, node->abstract() is null.
76 if (node_abs == nullptr) {
77 return kValueAny;
78 }
79 if (node_abs->isa<abstract::PrimitiveAbstractClosure>()) {
80 // Ignore the tracking_id and prim pointer hash.
81 auto prim_abs = node_abs->cast_ptr<abstract::PrimitiveAbstractClosure>();
82 return prim_abs->prim();
83 } else if (ignore_fg_abs_tracking_id && node_abs->isa<abstract::FuncGraphAbstractClosure>()) {
84 // Ignore the tracking_id.
85 return node_abs->cast_ptr<abstract::AbstractFunction>()->CopyWithoutTrackingId();
86 }
87 return node_abs;
88 }
89
BuildOrderGroupForOneGraph(const FuncGraphPtr & fg)90 bool CSE::BuildOrderGroupForOneGraph(const FuncGraphPtr &fg) {
91 MS_EXCEPTION_IF_NULL(fg);
92 std::vector<std::size_t> order_group;
93 mindspore::HashMap<std::size_t, std::vector<AnfNodePtr>> groups;
94 mindspore::HashMap<AnfNodePtr, std::size_t> hashes;
95
96 std::vector<AnfNodePtr> toposet = TopoSort(fg->get_return());
97 for (const auto &node : toposet) {
98 MS_EXCEPTION_IF_NULL(node);
99 if (hashes.find(node) != hashes.end()) {
100 continue;
101 }
102 if (IsHiddenSideEffectNode(node) && node->func_graph() != nullptr) {
103 MS_LOG(DEBUG) << "Add hidden func graph:" << node->func_graph()->ToString();
104 (void)hidden_side_effect_func_graphs_.insert(node->func_graph());
105 }
106 std::size_t h = 0;
107 if (node->isa<ValueNode>()) {
108 auto prim = GetValueNode<PrimitivePtr>(node);
109 if (IsPrimitiveEquals(prim, prim::kPrimUpdateState)) {
110 continue;
111 }
112 ValueNodePtr value_node = node->cast<ValueNodePtr>();
113 auto value = value_node->value();
114 MS_EXCEPTION_IF_NULL(value);
115 h = hash_combine(value->hash(), (AbsOf(value_node, true)->hash()));
116 } else if (node->isa<CNode>()) {
117 auto cnode = node->cast<CNodePtr>();
118 auto &inputs = cnode->inputs();
119 size_t init = 0;
120 h = std::accumulate(inputs.begin(), inputs.end(), init, [&hashes](std::size_t hash, const AnfNodePtr &node_in) {
121 return hash_combine(hash, hashes[node_in]);
122 });
123 } else if (node->isa<Parameter>()) {
124 h = node->hash();
125 } else {
126 MS_LOG(ERROR) << "Unknown node type";
127 }
128
129 hashes[node] = h;
130 if (groups.find(h) == groups.end()) {
131 std::vector<AnfNodePtr> innervec({node});
132 groups[h] = innervec;
133 order_group.emplace_back(h);
134 } else {
135 groups[h].push_back(node);
136 }
137 }
138 return CalReplaceNodes(order_group, &groups);
139 }
140
DoReplace(const FuncGraphManagerPtr & manager)141 void CSE::DoReplace(const FuncGraphManagerPtr &manager) {
142 // if A is a hidden_side_effect node, then A's user B can't be replaced by main, then B's user C can't be replaced by
143 // main.
144 auto tr = manager->Transact();
145 HashSet<AnfNodePtr> cannot_replace_nodes;
146 for (const auto &[node, main] : replicated_nodes_) {
147 bool main_input_cannot_replace = false;
148 if (main->isa<CNode>()) {
149 auto c_main = main->cast<CNodePtr>();
150 const auto &c_main_inputs = c_main->inputs();
151 auto input_can_not_replace = [&cannot_replace_nodes](const AnfNodePtr &node) {
152 return cannot_replace_nodes.find(node) != cannot_replace_nodes.cend();
153 };
154 main_input_cannot_replace = std::any_of(c_main_inputs.cbegin(), c_main_inputs.cend(), input_can_not_replace);
155 }
156 if (HasHiddenSideEffect(main) || main_input_cannot_replace) {
157 (void)cannot_replace_nodes.insert(main);
158 continue;
159 }
160 // We don't merge primitive cnodes with random effect.
161 MS_LOG(DEBUG) << "CSE replace, node:" << node->DebugString() << ", main:" << main->DebugString();
162 tr.Replace(node, main);
163 }
164 tr.Commit();
165 }
166
BuildOrderGroupAndDoReplace(const FuncGraphManagerPtr manager)167 bool CSE::BuildOrderGroupAndDoReplace(const FuncGraphManagerPtr manager) {
168 bool changed = false;
169 for (const auto &fg : manager->func_graphs()) {
170 changed = BuildOrderGroupForOneGraph(fg) || changed;
171 }
172 DoReplace(manager);
173 return changed;
174 }
175
176 // Check whether is a func graph call node and func graph has hidden side effect node.
IsHiddenSideEffectCall(const AnfNodePtr & node)177 bool CSE::IsHiddenSideEffectCall(const AnfNodePtr &node) {
178 if (!node->isa<CNode>()) {
179 return false;
180 }
181 auto cnode = node->cast<CNodePtr>();
182 // Check weather it is a func graph call.
183 if (IsValueNode<Primitive>(cnode->input(kAnfPrimitiveIndex))) {
184 return false;
185 }
186 // If it is a func graph call node, get all graphs from abstract.
187 auto func_graphs = abstract::GetFuncGraphsFromCallNode(cnode);
188 auto is_hidden_side_effect_graph = [this](const FuncGraphPtr &fg) -> bool {
189 return hidden_side_effect_func_graphs_.find(fg) != hidden_side_effect_func_graphs_.end();
190 };
191 return std::any_of(func_graphs.cbegin(), func_graphs.cend(), is_hidden_side_effect_graph);
192 }
193
HasHiddenSideEffect(const AnfNodePtr & node)194 bool CSE::HasHiddenSideEffect(const AnfNodePtr &node) {
195 if (IsHiddenSideEffectNode(node)) {
196 return true;
197 }
198 if (IsHiddenSideEffectCall(node)) {
199 return true;
200 }
201 return false;
202 }
203
// Returns the main node recorded as the replacement of `node`, or `node` itself when none is recorded.
// Throws if `node` is null.
AnfNodePtr CSE::GetReplicatedNode(const AnfNodePtr &node) const {
  MS_EXCEPTION_IF_NULL(node);
  const auto found = replicated_nodes_.find(node);
  return (found == replicated_nodes_.cend()) ? node : found->second;
}
212
AddReplicatedNode(const AnfNodePtr & node,const AnfNodePtr & main)213 void CSE::AddReplicatedNode(const AnfNodePtr &node, const AnfNodePtr &main) {
214 MS_EXCEPTION_IF_NULL(node);
215 MS_EXCEPTION_IF_NULL(main);
216 if (node == main) {
217 MS_LOG(WARNING) << "Can't replace node by itself, node:" << node->DebugString();
218 return;
219 }
220 (void)replicated_nodes_.emplace(node, main);
221 }
222
CheckReplace(const AnfNodePtr & main,const AnfNodePtr & node)223 bool CSE::CheckReplace(const AnfNodePtr &main, const AnfNodePtr &node) {
224 MS_EXCEPTION_IF_NULL(main);
225 MS_EXCEPTION_IF_NULL(node);
226 if (main->isa<ValueNode>() && node->isa<ValueNode>()) {
227 auto main_value = GetValueNode(main);
228 auto node_value = GetValueNode(node);
229 return (AbsOf(main, true) == AbsOf(node, true)) && (*main_value == *node_value);
230 } else if (main->isa<CNode>() && node->isa<CNode>()) {
231 auto c_main = main->cast<CNodePtr>();
232 auto c_node = node->cast<CNodePtr>();
233 // Not do cse for the node set recompute before the recompute pass.
234 if (IsSetRecomputed(c_main, c_node)) {
235 return false;
236 }
237 // Can not merge J because the J user size should be 1.
238 if (IsPrimitiveCNode(c_main, prim::kPrimJ) || IsPrimitiveCNode(c_main, prim::kPrimReceive)) {
239 return false;
240 }
241 if (IsPrimitiveCNode(node, prim::kPrimPyExecute)) {
242 return false;
243 }
244 const auto &inputs1 = c_main->inputs();
245 const auto &inputs2 = c_node->inputs();
246 if (inputs1.size() != inputs2.size()) {
247 return false;
248 }
249 // Check inputs, all inputs should equal.
250 for (size_t i = 0; i < inputs1.size(); i++) {
251 auto input1 = GetReplicatedNode(inputs1[i]);
252 auto input2 = GetReplicatedNode(inputs2[i]);
253 MS_EXCEPTION_IF_NULL(input1);
254 MS_EXCEPTION_IF_NULL(input2);
255 if ((input1 == input2) || (*input1 == *input2)) {
256 continue;
257 }
258 // Handle the case of two different Tensor, but with the same value.
259 if (IsValueNode<tensor::Tensor>(input1) && IsValueNode<tensor::Tensor>(input2)) {
260 auto tensor1 = GetValueNode<tensor::TensorPtr>(input1);
261 auto tensor2 = GetValueNode<tensor::TensorPtr>(input2);
262 if (tensor1->ValueEqual(*tensor2)) {
263 continue;
264 }
265 }
266 return false;
267 }
268 return true;
269 }
270 // a parameter node.
271 return false;
272 }
273
CalReplaceNodes(const std::vector<std::size_t> & order_group,mindspore::HashMap<std::size_t,std::vector<AnfNodePtr>> * groups)274 bool CSE::CalReplaceNodes(const std::vector<std::size_t> &order_group,
275 mindspore::HashMap<std::size_t, std::vector<AnfNodePtr>> *groups) {
276 bool changes = false;
277 std::set<size_t> clear_set;
278 for (auto &h : order_group) {
279 std::vector<AnfNodePtr> &group = (*groups)[h];
280 // If there are more than 2 node in that group, they may be same common expression can be eliminated.
281 if (group.size() > 1) {
282 for (size_t k = 0; k < group.size() - 1; k++) {
283 AnfNodePtr main = group[k];
284 MS_EXCEPTION_IF_NULL(main);
285
286 // When all node in group has been replaced
287 // or a valuenode node, skip compare in group
288 if ((k + 1 + clear_set.size() == group.size()) || (k > 0 && main->isa<ValueNode>())) {
289 break;
290 }
291
292 // skip node has been replaced
293 if (clear_set.find(k) != clear_set.end()) {
294 continue;
295 }
296
297 // Compare with rest elements in this group.
298 for (size_t i = k + 1; i < group.size(); i++) {
299 auto node = group[i];
300 MS_EXCEPTION_IF_NULL(node);
301
302 if (clear_set.find(i) != clear_set.end()) {
303 continue;
304 }
305 if (main->func_graph() != node->func_graph()) {
306 continue;
307 }
308 if (CheckReplace(node, main)) {
309 changes = true;
310 UpdateDebugInfoAndDumpFlag(main, node);
311 MS_LOG(DEBUG) << "Add replicated_nodes_, node:" << node->DebugString() << ", main:" << main->DebugString();
312 AddReplicatedNode(node, main);
313 (void)clear_set.insert(i);
314 }
315 }
316 }
317 clear_set.clear();
318 }
319 }
320 return changes;
321 }
322
Init()323 void CSE::Init() {
324 hidden_side_effect_func_graphs_.clear();
325 replicated_nodes_.clear();
326 }
327
Cse(const FuncGraphPtr root,const FuncGraphManagerPtr manager)328 bool CSE::Cse(const FuncGraphPtr root, const FuncGraphManagerPtr manager) {
329 MS_EXCEPTION_IF_NULL(manager);
330 Init();
331 manager->AddFuncGraph(root);
332 return BuildOrderGroupAndDoReplace(manager);
333 }
334 } // namespace opt
335 } // namespace mindspore
336