• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * This is the C++ adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/).
3  *
4  * Copyright 2019-2022 Huawei Technologies Co., Ltd
5  *
6  * Licensed under the Apache License, Version 2.0 (the "License");
7  * you may not use this file except in compliance with the License.
8  * You may obtain a copy of the License at
9  *
10  * http://www.apache.org/licenses/LICENSE-2.0
11  *
12  * Unless required by applicable law or agreed to in writing, software
13  * distributed under the License is distributed on an "AS IS" BASIS,
14  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15  * See the License for the specific language governing permissions and
16  * limitations under the License.
17  */
18 
19 #include "ir/graph_utils.h"
20 
21 #include <algorithm>
22 #include <deque>
23 #include <memory>
24 #include <set>
25 #include <utility>
26 
27 #include "ir/anf.h"
28 #include "ir/func_graph.h"
29 #include "utils/hash_map.h"
30 #include "utils/hash_set.h"
31 #include "utils/log_adapter.h"
32 #include "utils/ms_context.h"
33 #include "include/common/utils/utils.h"
34 
35 namespace mindspore {
36 namespace {
37 // Dump the circle from the strike node `next`.
DumpSortingCircleList(const std::deque<AnfNodePtr> & todo,const AnfNodePtr & next,SeenNum seen)38 static size_t DumpSortingCircleList(const std::deque<AnfNodePtr> &todo, const AnfNodePtr &next, SeenNum seen) {
39   size_t pos = 0;
40   auto circle_node_it = std::find(todo.begin(), todo.end(), next);
41   for (; circle_node_it != todo.end(); ++circle_node_it) {
42     auto circle_node = *circle_node_it;
43     if (circle_node->seen_ == seen) {
44       MS_LOG(ERROR) << "#" << pos << ": " << circle_node->DebugString();
45       ++pos;
46     }
47   }
48   return pos;
49 }
50 
51 static DumpIRPrividerFunction dump_ir_privider{nullptr};
DumpIRPrivider()52 DumpIRPrividerFunction DumpIRPrivider() { return dump_ir_privider; }
53 
54 static DumpIRStorageFunction dump_ir_storage{nullptr};
DumpIRStorage()55 DumpIRStorageFunction DumpIRStorage() { return dump_ir_storage; }
56 
57 // DumpIR for all func graphs in the circle, and print circle indicators in the IR file.
DumpSortingCircleIr(const std::deque<AnfNodePtr> & todo,const AnfNodePtr & next,SeenNum seen)58 void DumpSortingCircleIr(const std::deque<AnfNodePtr> &todo, const AnfNodePtr &next, SeenNum seen) {
59   if (DumpIRPrivider() == nullptr || DumpIRStorage() == nullptr) {
60     MS_LOG(DEBUG) << "DumpIR privider is null";
61     return;
62   }
63   std::set<FuncGraphPtr> func_graph_set;
64   size_t pos = 0;
65   auto circle_node_it = std::find(todo.begin(), todo.end(), next);
66   for (; circle_node_it != todo.end(); ++circle_node_it) {
67     auto circle_node = *circle_node_it;
68     if (circle_node->seen_ == seen) {
69       if (circle_node->func_graph() != nullptr && func_graph_set.count(circle_node->func_graph()) == 0) {
70         (void)func_graph_set.emplace(circle_node->func_graph());
71       }
72       circle_node->set_user_data<size_t>(kTopoSortCircle, std::make_shared<size_t>(pos));
73       ++pos;
74     }
75   }
76   if (func_graph_set.empty()) {
77     MS_LOG(ERROR) << "At least one func graph if there's a TopoSort circle.";
78     return;
79   }
80   std::ostringstream graph_buffer;
81   graph_buffer << "# ===========================================================================\n"
82                << "# Graph cycle exists during TopoSort.\n"
83                << "# Total graphs: " << func_graph_set.size() << "\n#\n"
84                << "# You can search ------------------------> " << (pos - 1) << ",\n"
85                << "# to locate the node who leads to the circle.\n"
86                << "# ===========================================================================\n\n";
87   for (const auto &graph : func_graph_set) {
88     DumpIRPrivider()(graph_buffer, graph, false, 0, true);
89   }
90   DumpIRStorage()("TOPO_SORT_CIRCLE_GRAPHS_" + std::to_string(func_graph_set.size()) + ".ir", graph_buffer.str(), "");
91 }
92 }  // namespace
93 
SetDumpIRPrivider(const DumpIRPrividerFunction & func)94 void SetDumpIRPrivider(const DumpIRPrividerFunction &func) { dump_ir_privider = func; }
95 
SetDumpIRStorage(const DumpIRStorageFunction & func)96 void SetDumpIRStorage(const DumpIRStorageFunction &func) { dump_ir_storage = func; }
97 
TopoSort(const AnfNodePtr & root,const SuccFunc & succ,const IncludeFunc & include,bool exclude_circle_node)98 AnfNodePtrList TopoSort(const AnfNodePtr &root, const SuccFunc &succ, const IncludeFunc &include,
99                         bool exclude_circle_node) {
100   AnfNodePtrList res;
101   if (root == nullptr) {
102     return res;
103   }
104   constexpr auto vector_reserve_size = 64;
105   res.reserve(vector_reserve_size);
106   auto seen = NewSeenGeneration();
107   std::deque<AnfNodePtr> todo;
108   (void)todo.emplace_back(root);
109   while (!todo.empty()) {
110     AnfNodePtr &node = todo.back();
111     if (node->extra_seen_ == seen) {  // We use extra_seen_ as finish flag
112       todo.pop_back();
113       continue;
114     }
115     auto incl = include(node);
116     if (node->seen_ == seen) {  // We use seen_ as checking flag
117       node->extra_seen_ = seen;
118       if (incl != EXCLUDE) {
119         (void)res.emplace_back(std::move(node));
120       }
121       todo.pop_back();
122       continue;
123     }
124     node->seen_ = seen;
125     if (incl == FOLLOW) {
126       for (auto &weak_next : succ(node)) {
127         auto next = weak_next.lock();
128         if (next == nullptr || next->extra_seen_ == seen) {
129           continue;
130         }
131         if (next->seen_ != seen) {
132           (void)todo.emplace_back(std::move(next));
133           continue;
134         }
135         auto fg = next->func_graph();
136         if (fg != nullptr && fg->return_node() == next) {
137           continue;
138         }
139         constexpr auto recursive_level = 2;
140         if (exclude_circle_node) {
141           MS_LOG(INFO) << "Graph cycle exists, exclude circle strike node: " << next->DebugString(recursive_level);
142           continue;
143         }
144         // To dump all nodes in a circle.
145         MS_LOG(ERROR) << "Graph cycle exists, strike node: " << next->DebugString(recursive_level) << "\nCircle is: ";
146         auto circle_len = DumpSortingCircleList(todo, next, seen);
147         DumpSortingCircleIr(todo, next, seen);
148         MS_LOG(INTERNAL_EXCEPTION) << "Graph cycle exists, size: " << circle_len
149                                    << ", strike node: " << next->DebugString(recursive_level);
150       }
151     } else if (incl > EXCLUDE) {  // Not NOFOLLOW or EXCLUDE
152       MS_LOG(INTERNAL_EXCEPTION) << "The result of include(node) must be one of: \"follow\", \"nofollow\", \"exclude\"";
153     }
154   }
155   return res;
156 }
157 
158 // @deprecated
159 // To use 'AnfNodePtrList TopoSort(const AnfNodePtr &, const SuccFunc &, const IncludeFunc &, bool)' instead.
TopoSort(const AnfNodePtr & root,const DeprecatedSuccFunc & deprecated_succ,const IncludeFunc & include,bool exclude_circle_node)160 AnfNodePtrList TopoSort(const AnfNodePtr &root, const DeprecatedSuccFunc &deprecated_succ, const IncludeFunc &include,
161                         bool exclude_circle_node) {
162   SuccFunc compatible_adapter_succ = [&deprecated_succ](const AnfNodePtr &node) -> AnfNodeWeakPtrList {
163     auto nodes = deprecated_succ(node);
164     AnfNodeWeakPtrList weak_nodes;
165     weak_nodes.reserve(nodes.size());
166     std::transform(nodes.cbegin(), nodes.cend(), std::back_inserter(weak_nodes),
167                    [](const AnfNodePtr &node) -> AnfNodeWeakPtr { return AnfNodeWeakPtr(node); });
168     return weak_nodes;
169   };
170   return TopoSort(root, compatible_adapter_succ, include, exclude_circle_node);
171 }
172 
173 // Search all CNode in root's graph only.
BroadFirstSearchGraphCNodes(const CNodePtr & root)174 std::vector<CNodePtr> BroadFirstSearchGraphCNodes(const CNodePtr &root) {
175   constexpr size_t kVecReserve = 64;
176   std::vector<CNodePtr> cnodes;
177   cnodes.reserve(kVecReserve);
178   auto seen = NewSeenGeneration();
179   MS_EXCEPTION_IF_NULL(root);
180   root->seen_ = seen;
181   (void)cnodes.emplace_back(root);
182   for (size_t i = 0; i < cnodes.size(); ++i) {
183     CNodePtr &node = cnodes[i];
184     for (auto &weak_input : node->weak_inputs()) {
185       auto input = weak_input.lock();
186       if (input == nullptr) {
187         MS_LOG(INTERNAL_EXCEPTION) << "The input is null, node: " << node << "/" << node->DebugString();
188       }
189       if (input->seen_ == seen) {
190         continue;
191       }
192       input->seen_ = seen;
193       auto input_cnode = input->cast<CNodePtr>();
194       if (input_cnode != nullptr) {
195         (void)cnodes.emplace_back(std::move(input_cnode));
196       }
197     }
198   }
199   return cnodes;
200 }
201 
202 // Search all CNode match the predicate in roots' graph only.
BroadFirstSearchFirstOf(const std::vector<CNodePtr> & roots,const MatchFunc & match_predicate)203 CNodePtr BroadFirstSearchFirstOf(const std::vector<CNodePtr> &roots, const MatchFunc &match_predicate) {
204   std::deque<CNodePtr> todo;
205   (void)todo.insert(todo.end(), roots.begin(), roots.end());
206   auto seen = NewSeenGeneration();
207   while (!todo.empty()) {
208     CNodePtr top = todo.front();
209     todo.pop_front();
210     if (match_predicate(top)) {
211       return top;
212     }
213     for (auto &weak_input : top->weak_inputs()) {
214       auto input = weak_input.lock();
215       MS_EXCEPTION_IF_NULL(input);
216       if (input->seen_ == seen) {
217         continue;
218       }
219 
220       if (input->isa<CNode>()) {
221         todo.push_back(input->cast<CNodePtr>());
222       }
223       input->seen_ = seen;
224     }
225   }
226   return nullptr;
227 }
228 
BroadFirstSearchGraphUsed(const FuncGraphPtr & root,const GraphFilterFunc & filter)229 std::vector<FuncGraphPtr> BroadFirstSearchGraphUsed(const FuncGraphPtr &root, const GraphFilterFunc &filter) {
230   std::vector<FuncGraphPtr> todo;
231   todo.push_back(root);
232   auto seen = NewSeenGeneration();
233   size_t top_idx = 0;
234   while (top_idx < todo.size()) {
235     FuncGraphPtr top = todo[top_idx];
236     top_idx++;
237     auto used = top->func_graphs_used();
238     for (auto &item : used) {
239       if (item.first->seen_ == seen) {
240         continue;
241       }
242       if (filter && filter(item.first)) {
243         continue;
244       }
245       todo.push_back(item.first);
246       item.first->seen_ = seen;
247     }
248   }
249   return todo;
250 }
251 
252 // To get CNode inputs to a vector as successors for TopoSort().
FetchCNodeSuccessors(const CNodePtr & cnode,AnfNodeWeakPtrList * vecs)253 static void FetchCNodeSuccessors(const CNodePtr &cnode, AnfNodeWeakPtrList *vecs) {
254   auto &inputs = cnode->weak_inputs();
255   vecs->reserve(vecs->size() + inputs.size());
256 
257   // To keep sort order from left to right in default, if kAttrTopoSortRhsFirst not set.
258   auto attr_sort_rhs_first = cnode->GetAttr(kAttrTopoSortRhsFirst);
259   auto sort_rhs_first =
260     attr_sort_rhs_first != nullptr && attr_sort_rhs_first->isa<BoolImm>() && GetValue<bool>(attr_sort_rhs_first);
261   if (sort_rhs_first) {
262     (void)vecs->insert(vecs->end(), inputs.cbegin(), inputs.cend());
263   } else {
264     (void)vecs->insert(vecs->end(), inputs.crbegin(), inputs.crend());
265   }
266 }
267 
SuccDeeperSimple(const AnfNodePtr & node)268 AnfNodeWeakPtrList SuccDeeperSimple(const AnfNodePtr &node) {
269   AnfNodeWeakPtrList vecs;
270   if (node == nullptr) {
271     return vecs;
272   }
273 
274   auto graph = GetValuePtr<FuncGraph>(node);
275   if (graph != nullptr) {
276     auto &res = graph->return_node();
277     if (res != nullptr) {
278       vecs.push_back(res);
279     }
280   } else if (node->isa<CNode>()) {
281     FetchCNodeSuccessors(node->cast<CNodePtr>(), &vecs);
282   }
283   return vecs;
284 }
285 
SuccIncoming(const AnfNodePtr & node)286 AnfNodeWeakPtrList SuccIncoming(const AnfNodePtr &node) {
287   AnfNodeWeakPtrList vecs;
288   auto cnode = dyn_cast<CNode>(node);
289   if (cnode != nullptr) {
290     FetchCNodeSuccessors(cnode, &vecs);
291   }
292   return vecs;
293 }
294 
SuccIncludeFV(const FuncGraphPtr & fg,const AnfNodePtr & node)295 AnfNodeWeakPtrList SuccIncludeFV(const FuncGraphPtr &fg, const AnfNodePtr &node) {
296   auto cnode = dyn_cast<CNode>(node);
297   if (cnode == nullptr) {
298     return {};
299   }
300   AnfNodeWeakPtrList vecs;
301   const auto &inputs = cnode->inputs();
302   // Check if free variables used.
303   for (const auto &input : inputs) {
304     auto input_fg = GetValuePtr<FuncGraph>(input);
305     if (input_fg != nullptr) {
306       for (auto &fv : input_fg->free_variables_nodes()) {
307         MS_EXCEPTION_IF_NULL(fv);
308         if (fv->func_graph() == fg && fg->nodes().contains(fv)) {
309           vecs.push_back(fv);
310         }
311       }
312     }
313   }
314   FetchCNodeSuccessors(cnode, &vecs);
315   return vecs;
316 }
317 
SuccWithFilter(const GraphFilterFunc & graph_filter,const AnfNodePtr & node)318 AnfNodeWeakPtrList SuccWithFilter(const GraphFilterFunc &graph_filter, const AnfNodePtr &node) {
319   AnfNodeWeakPtrList vecs;
320   if (node == nullptr) {
321     return vecs;
322   }
323 
324   auto graph = GetValueNode<FuncGraphPtr>(node);
325   if (graph != nullptr) {
326     if (graph_filter != nullptr && graph_filter(graph)) {
327       return vecs;
328     }
329     auto &res = graph->return_node();
330     if (res != nullptr) {
331       vecs.push_back(res);
332     }
333   } else if (node->isa<CNode>()) {
334     FetchCNodeSuccessors(node->cast<CNodePtr>(), &vecs);
335   }
336   return vecs;
337 }
338 
GetInputs(const AnfNodePtr & node)339 const AnfNodePtrList GetInputs(const AnfNodePtr &node) {
340   static AnfNodePtrList empty_inputs;
341   auto cnode = dyn_cast_ptr<CNode>(node);
342   if (cnode != nullptr) {
343     return cnode->inputs();
344   }
345   return empty_inputs;
346 }
347 
GetWeakInputs(const AnfNodePtr & node)348 const AnfNodeWeakPtrList &GetWeakInputs(const AnfNodePtr &node) {
349   static AnfNodeWeakPtrList empty_inputs;
350   auto cnode = dyn_cast_ptr<CNode>(node);
351   if (cnode != nullptr) {
352     return cnode->weak_inputs();
353   }
354   return empty_inputs;
355 }
356 
IncludeBelongGraph(const FuncGraphPtr & fg,const AnfNodePtr & node)357 IncludeType IncludeBelongGraph(const FuncGraphPtr &fg, const AnfNodePtr &node) {
358   if (node->func_graph() == fg) {
359     return FOLLOW;
360   } else {
361     return EXCLUDE;
362   }
363 }
364 }  // namespace mindspore
365