1 /**
2 * This is the C++ adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/).
3 *
4 * Copyright 2019-2022 Huawei Technologies Co., Ltd
5 *
6 * Licensed under the Apache License, Version 2.0 (the "License");
7 * you may not use this file except in compliance with the License.
8 * You may obtain a copy of the License at
9 *
10 * http://www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing, software
13 * distributed under the License is distributed on an "AS IS" BASIS,
14 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15 * See the License for the specific language governing permissions and
16 * limitations under the License.
17 */
18
#include "ir/graph_utils.h"

#include <algorithm>
#include <deque>
#include <iterator>
#include <memory>
#include <set>
#include <sstream>
#include <string>
#include <utility>
#include <vector>

#include "ir/anf.h"
#include "ir/func_graph.h"
#include "utils/hash_map.h"
#include "utils/hash_set.h"
#include "utils/log_adapter.h"
#include "utils/ms_context.h"
#include "include/common/utils/utils.h"
34
35 namespace mindspore {
36 namespace {
37 // Dump the circle from the strike node `next`.
DumpSortingCircleList(const std::deque<AnfNodePtr> & todo,const AnfNodePtr & next,SeenNum seen)38 static size_t DumpSortingCircleList(const std::deque<AnfNodePtr> &todo, const AnfNodePtr &next, SeenNum seen) {
39 size_t pos = 0;
40 auto circle_node_it = std::find(todo.begin(), todo.end(), next);
41 for (; circle_node_it != todo.end(); ++circle_node_it) {
42 auto circle_node = *circle_node_it;
43 if (circle_node->seen_ == seen) {
44 MS_LOG(ERROR) << "#" << pos << ": " << circle_node->DebugString();
45 ++pos;
46 }
47 }
48 return pos;
49 }
50
51 static DumpIRPrividerFunction dump_ir_privider{nullptr};
DumpIRPrivider()52 DumpIRPrividerFunction DumpIRPrivider() { return dump_ir_privider; }
53
54 static DumpIRStorageFunction dump_ir_storage{nullptr};
DumpIRStorage()55 DumpIRStorageFunction DumpIRStorage() { return dump_ir_storage; }
56
57 // DumpIR for all func graphs in the circle, and print circle indicators in the IR file.
DumpSortingCircleIr(const std::deque<AnfNodePtr> & todo,const AnfNodePtr & next,SeenNum seen)58 void DumpSortingCircleIr(const std::deque<AnfNodePtr> &todo, const AnfNodePtr &next, SeenNum seen) {
59 if (DumpIRPrivider() == nullptr || DumpIRStorage() == nullptr) {
60 MS_LOG(DEBUG) << "DumpIR privider is null";
61 return;
62 }
63 std::set<FuncGraphPtr> func_graph_set;
64 size_t pos = 0;
65 auto circle_node_it = std::find(todo.begin(), todo.end(), next);
66 for (; circle_node_it != todo.end(); ++circle_node_it) {
67 auto circle_node = *circle_node_it;
68 if (circle_node->seen_ == seen) {
69 if (circle_node->func_graph() != nullptr && func_graph_set.count(circle_node->func_graph()) == 0) {
70 (void)func_graph_set.emplace(circle_node->func_graph());
71 }
72 circle_node->set_user_data<size_t>(kTopoSortCircle, std::make_shared<size_t>(pos));
73 ++pos;
74 }
75 }
76 if (func_graph_set.empty()) {
77 MS_LOG(ERROR) << "At least one func graph if there's a TopoSort circle.";
78 return;
79 }
80 std::ostringstream graph_buffer;
81 graph_buffer << "# ===========================================================================\n"
82 << "# Graph cycle exists during TopoSort.\n"
83 << "# Total graphs: " << func_graph_set.size() << "\n#\n"
84 << "# You can search ------------------------> " << (pos - 1) << ",\n"
85 << "# to locate the node who leads to the circle.\n"
86 << "# ===========================================================================\n\n";
87 for (const auto &graph : func_graph_set) {
88 DumpIRPrivider()(graph_buffer, graph, false, 0, true);
89 }
90 DumpIRStorage()("TOPO_SORT_CIRCLE_GRAPHS_" + std::to_string(func_graph_set.size()) + ".ir", graph_buffer.str(), "");
91 }
92 } // namespace
93
SetDumpIRPrivider(const DumpIRPrividerFunction & func)94 void SetDumpIRPrivider(const DumpIRPrividerFunction &func) { dump_ir_privider = func; }
95
SetDumpIRStorage(const DumpIRStorageFunction & func)96 void SetDumpIRStorage(const DumpIRStorageFunction &func) { dump_ir_storage = func; }
97
TopoSort(const AnfNodePtr & root,const SuccFunc & succ,const IncludeFunc & include,bool exclude_circle_node)98 AnfNodePtrList TopoSort(const AnfNodePtr &root, const SuccFunc &succ, const IncludeFunc &include,
99 bool exclude_circle_node) {
100 AnfNodePtrList res;
101 if (root == nullptr) {
102 return res;
103 }
104 constexpr auto vector_reserve_size = 64;
105 res.reserve(vector_reserve_size);
106 auto seen = NewSeenGeneration();
107 std::deque<AnfNodePtr> todo;
108 (void)todo.emplace_back(root);
109 while (!todo.empty()) {
110 AnfNodePtr &node = todo.back();
111 if (node->extra_seen_ == seen) { // We use extra_seen_ as finish flag
112 todo.pop_back();
113 continue;
114 }
115 auto incl = include(node);
116 if (node->seen_ == seen) { // We use seen_ as checking flag
117 node->extra_seen_ = seen;
118 if (incl != EXCLUDE) {
119 (void)res.emplace_back(std::move(node));
120 }
121 todo.pop_back();
122 continue;
123 }
124 node->seen_ = seen;
125 if (incl == FOLLOW) {
126 for (auto &weak_next : succ(node)) {
127 auto next = weak_next.lock();
128 if (next == nullptr || next->extra_seen_ == seen) {
129 continue;
130 }
131 if (next->seen_ != seen) {
132 (void)todo.emplace_back(std::move(next));
133 continue;
134 }
135 auto fg = next->func_graph();
136 if (fg != nullptr && fg->return_node() == next) {
137 continue;
138 }
139 constexpr auto recursive_level = 2;
140 if (exclude_circle_node) {
141 MS_LOG(INFO) << "Graph cycle exists, exclude circle strike node: " << next->DebugString(recursive_level);
142 continue;
143 }
144 // To dump all nodes in a circle.
145 MS_LOG(ERROR) << "Graph cycle exists, strike node: " << next->DebugString(recursive_level) << "\nCircle is: ";
146 auto circle_len = DumpSortingCircleList(todo, next, seen);
147 DumpSortingCircleIr(todo, next, seen);
148 MS_LOG(INTERNAL_EXCEPTION) << "Graph cycle exists, size: " << circle_len
149 << ", strike node: " << next->DebugString(recursive_level);
150 }
151 } else if (incl > EXCLUDE) { // Not NOFOLLOW or EXCLUDE
152 MS_LOG(INTERNAL_EXCEPTION) << "The result of include(node) must be one of: \"follow\", \"nofollow\", \"exclude\"";
153 }
154 }
155 return res;
156 }
157
158 // @deprecated
159 // To use 'AnfNodePtrList TopoSort(const AnfNodePtr &, const SuccFunc &, const IncludeFunc &, bool)' instead.
TopoSort(const AnfNodePtr & root,const DeprecatedSuccFunc & deprecated_succ,const IncludeFunc & include,bool exclude_circle_node)160 AnfNodePtrList TopoSort(const AnfNodePtr &root, const DeprecatedSuccFunc &deprecated_succ, const IncludeFunc &include,
161 bool exclude_circle_node) {
162 SuccFunc compatible_adapter_succ = [&deprecated_succ](const AnfNodePtr &node) -> AnfNodeWeakPtrList {
163 auto nodes = deprecated_succ(node);
164 AnfNodeWeakPtrList weak_nodes;
165 weak_nodes.reserve(nodes.size());
166 std::transform(nodes.cbegin(), nodes.cend(), std::back_inserter(weak_nodes),
167 [](const AnfNodePtr &node) -> AnfNodeWeakPtr { return AnfNodeWeakPtr(node); });
168 return weak_nodes;
169 };
170 return TopoSort(root, compatible_adapter_succ, include, exclude_circle_node);
171 }
172
173 // Search all CNode in root's graph only.
BroadFirstSearchGraphCNodes(const CNodePtr & root)174 std::vector<CNodePtr> BroadFirstSearchGraphCNodes(const CNodePtr &root) {
175 constexpr size_t kVecReserve = 64;
176 std::vector<CNodePtr> cnodes;
177 cnodes.reserve(kVecReserve);
178 auto seen = NewSeenGeneration();
179 MS_EXCEPTION_IF_NULL(root);
180 root->seen_ = seen;
181 (void)cnodes.emplace_back(root);
182 for (size_t i = 0; i < cnodes.size(); ++i) {
183 CNodePtr &node = cnodes[i];
184 for (auto &weak_input : node->weak_inputs()) {
185 auto input = weak_input.lock();
186 if (input == nullptr) {
187 MS_LOG(INTERNAL_EXCEPTION) << "The input is null, node: " << node << "/" << node->DebugString();
188 }
189 if (input->seen_ == seen) {
190 continue;
191 }
192 input->seen_ = seen;
193 auto input_cnode = input->cast<CNodePtr>();
194 if (input_cnode != nullptr) {
195 (void)cnodes.emplace_back(std::move(input_cnode));
196 }
197 }
198 }
199 return cnodes;
200 }
201
202 // Search all CNode match the predicate in roots' graph only.
BroadFirstSearchFirstOf(const std::vector<CNodePtr> & roots,const MatchFunc & match_predicate)203 CNodePtr BroadFirstSearchFirstOf(const std::vector<CNodePtr> &roots, const MatchFunc &match_predicate) {
204 std::deque<CNodePtr> todo;
205 (void)todo.insert(todo.end(), roots.begin(), roots.end());
206 auto seen = NewSeenGeneration();
207 while (!todo.empty()) {
208 CNodePtr top = todo.front();
209 todo.pop_front();
210 if (match_predicate(top)) {
211 return top;
212 }
213 for (auto &weak_input : top->weak_inputs()) {
214 auto input = weak_input.lock();
215 MS_EXCEPTION_IF_NULL(input);
216 if (input->seen_ == seen) {
217 continue;
218 }
219
220 if (input->isa<CNode>()) {
221 todo.push_back(input->cast<CNodePtr>());
222 }
223 input->seen_ = seen;
224 }
225 }
226 return nullptr;
227 }
228
BroadFirstSearchGraphUsed(const FuncGraphPtr & root,const GraphFilterFunc & filter)229 std::vector<FuncGraphPtr> BroadFirstSearchGraphUsed(const FuncGraphPtr &root, const GraphFilterFunc &filter) {
230 std::vector<FuncGraphPtr> todo;
231 todo.push_back(root);
232 auto seen = NewSeenGeneration();
233 size_t top_idx = 0;
234 while (top_idx < todo.size()) {
235 FuncGraphPtr top = todo[top_idx];
236 top_idx++;
237 auto used = top->func_graphs_used();
238 for (auto &item : used) {
239 if (item.first->seen_ == seen) {
240 continue;
241 }
242 if (filter && filter(item.first)) {
243 continue;
244 }
245 todo.push_back(item.first);
246 item.first->seen_ = seen;
247 }
248 }
249 return todo;
250 }
251
252 // To get CNode inputs to a vector as successors for TopoSort().
FetchCNodeSuccessors(const CNodePtr & cnode,AnfNodeWeakPtrList * vecs)253 static void FetchCNodeSuccessors(const CNodePtr &cnode, AnfNodeWeakPtrList *vecs) {
254 auto &inputs = cnode->weak_inputs();
255 vecs->reserve(vecs->size() + inputs.size());
256
257 // To keep sort order from left to right in default, if kAttrTopoSortRhsFirst not set.
258 auto attr_sort_rhs_first = cnode->GetAttr(kAttrTopoSortRhsFirst);
259 auto sort_rhs_first =
260 attr_sort_rhs_first != nullptr && attr_sort_rhs_first->isa<BoolImm>() && GetValue<bool>(attr_sort_rhs_first);
261 if (sort_rhs_first) {
262 (void)vecs->insert(vecs->end(), inputs.cbegin(), inputs.cend());
263 } else {
264 (void)vecs->insert(vecs->end(), inputs.crbegin(), inputs.crend());
265 }
266 }
267
SuccDeeperSimple(const AnfNodePtr & node)268 AnfNodeWeakPtrList SuccDeeperSimple(const AnfNodePtr &node) {
269 AnfNodeWeakPtrList vecs;
270 if (node == nullptr) {
271 return vecs;
272 }
273
274 auto graph = GetValuePtr<FuncGraph>(node);
275 if (graph != nullptr) {
276 auto &res = graph->return_node();
277 if (res != nullptr) {
278 vecs.push_back(res);
279 }
280 } else if (node->isa<CNode>()) {
281 FetchCNodeSuccessors(node->cast<CNodePtr>(), &vecs);
282 }
283 return vecs;
284 }
285
SuccIncoming(const AnfNodePtr & node)286 AnfNodeWeakPtrList SuccIncoming(const AnfNodePtr &node) {
287 AnfNodeWeakPtrList vecs;
288 auto cnode = dyn_cast<CNode>(node);
289 if (cnode != nullptr) {
290 FetchCNodeSuccessors(cnode, &vecs);
291 }
292 return vecs;
293 }
294
SuccIncludeFV(const FuncGraphPtr & fg,const AnfNodePtr & node)295 AnfNodeWeakPtrList SuccIncludeFV(const FuncGraphPtr &fg, const AnfNodePtr &node) {
296 auto cnode = dyn_cast<CNode>(node);
297 if (cnode == nullptr) {
298 return {};
299 }
300 AnfNodeWeakPtrList vecs;
301 const auto &inputs = cnode->inputs();
302 // Check if free variables used.
303 for (const auto &input : inputs) {
304 auto input_fg = GetValuePtr<FuncGraph>(input);
305 if (input_fg != nullptr) {
306 for (auto &fv : input_fg->free_variables_nodes()) {
307 MS_EXCEPTION_IF_NULL(fv);
308 if (fv->func_graph() == fg && fg->nodes().contains(fv)) {
309 vecs.push_back(fv);
310 }
311 }
312 }
313 }
314 FetchCNodeSuccessors(cnode, &vecs);
315 return vecs;
316 }
317
SuccWithFilter(const GraphFilterFunc & graph_filter,const AnfNodePtr & node)318 AnfNodeWeakPtrList SuccWithFilter(const GraphFilterFunc &graph_filter, const AnfNodePtr &node) {
319 AnfNodeWeakPtrList vecs;
320 if (node == nullptr) {
321 return vecs;
322 }
323
324 auto graph = GetValueNode<FuncGraphPtr>(node);
325 if (graph != nullptr) {
326 if (graph_filter != nullptr && graph_filter(graph)) {
327 return vecs;
328 }
329 auto &res = graph->return_node();
330 if (res != nullptr) {
331 vecs.push_back(res);
332 }
333 } else if (node->isa<CNode>()) {
334 FetchCNodeSuccessors(node->cast<CNodePtr>(), &vecs);
335 }
336 return vecs;
337 }
338
GetInputs(const AnfNodePtr & node)339 const AnfNodePtrList GetInputs(const AnfNodePtr &node) {
340 static AnfNodePtrList empty_inputs;
341 auto cnode = dyn_cast_ptr<CNode>(node);
342 if (cnode != nullptr) {
343 return cnode->inputs();
344 }
345 return empty_inputs;
346 }
347
GetWeakInputs(const AnfNodePtr & node)348 const AnfNodeWeakPtrList &GetWeakInputs(const AnfNodePtr &node) {
349 static AnfNodeWeakPtrList empty_inputs;
350 auto cnode = dyn_cast_ptr<CNode>(node);
351 if (cnode != nullptr) {
352 return cnode->weak_inputs();
353 }
354 return empty_inputs;
355 }
356
IncludeBelongGraph(const FuncGraphPtr & fg,const AnfNodePtr & node)357 IncludeType IncludeBelongGraph(const FuncGraphPtr &fg, const AnfNodePtr &node) {
358 if (node->func_graph() == fg) {
359 return FOLLOW;
360 } else {
361 return EXCLUDE;
362 }
363 }
364 } // namespace mindspore
365