1 /**
2 * This is the C++ adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/).
3 *
4 * Copyright 2019-2020 Huawei Technologies Co., Ltd
5 *
6 * Licensed under the Apache License, Version 2.0 (the "License");
7 * you may not use this file except in compliance with the License.
8 * You may obtain a copy of the License at
9 *
10 * http://www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing, software
13 * distributed under the License is distributed on an "AS IS" BASIS,
14 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15 * See the License for the specific language governing permissions and
16 * limitations under the License.
17 */
18
19 #include "ir/graph_utils.h"
20
21 #include <unordered_map>
22 #include <unordered_set>
23 #include <utility>
24 #include <stack>
25 #include <vector>
26 #include <tuple>
27 #include <string>
28 #include <fstream>
29 #include <deque>
30 #include <set>
31
32 #include "ir/func_graph.h"
33 #include "utils/log_adapter.h"
34 #include "utils/ms_context.h"
35 #include "mindspore/ccsrc/utils/utils.h"
36
37 namespace mindspore {
38 // Dump the circle from the strike node `next`.
DumpSortingCircleList(const std::deque<AnfNodePtr> & todo,const AnfNodePtr & next,size_t seen)39 static size_t DumpSortingCircleList(const std::deque<AnfNodePtr> &todo, const AnfNodePtr &next, size_t seen) {
40 size_t pos = 0;
41 auto circle_node_it = std::find(todo.begin(), todo.end(), next);
42 for (; circle_node_it != todo.end(); circle_node_it++) {
43 auto circle_node = *circle_node_it;
44 if (circle_node->seen_ == seen) {
45 MS_LOG(ERROR) << "#" << pos << ": " << circle_node->DebugString();
46 pos++;
47 }
48 }
49 return pos;
50 }
51
TopoSort(const AnfNodePtr & root,const SuccFunc & succ,const IncludeFunc & include)52 std::vector<AnfNodePtr> TopoSort(const AnfNodePtr &root, const SuccFunc &succ, const IncludeFunc &include) {
53 std::vector<AnfNodePtr> res;
54 if (root == nullptr) {
55 return res;
56 }
57 size_t seen = NewSeenGeneration();
58 std::deque<AnfNodePtr> todo;
59 todo.push_back(root);
60
61 while (!todo.empty()) {
62 AnfNodePtr node = todo.back();
63 if (node->extra_seen_ == seen) { // We use extra_seen_ as finish flag
64 todo.pop_back();
65 continue;
66 }
67 auto incl = include(node);
68 if (node->seen_ == seen) { // We use seen_ as checking flag
69 todo.pop_back();
70 if (incl != EXCLUDE) {
71 res.push_back(node);
72 }
73 node->extra_seen_ = seen;
74 continue;
75 }
76 node->seen_ = seen;
77 if (incl == FOLLOW) {
78 auto succs = succ(node);
79 (void)std::copy_if(succs.begin(), succs.end(), std::back_inserter(todo), [seen, &todo](const AnfNodePtr &next) {
80 if (next == nullptr || next->extra_seen_ == seen) {
81 return false;
82 }
83 if (next->seen_ != seen) {
84 return true;
85 }
86 if (next->func_graph() != nullptr && next->func_graph()->get_return() == next) {
87 return false;
88 }
89 // To dump all nodes in a circle.
90 MS_LOG(ERROR) << "Graph cycle exists. Circle is: ";
91 auto circle_len = DumpSortingCircleList(todo, next, seen);
92 MS_LOG(EXCEPTION) << "Graph cycle exists, size: " << circle_len << ", strike node: " << next->DebugString(2);
93 });
94 } else if (incl > EXCLUDE) { // Not NOFOLLOW or EXCLUDE
95 MS_LOG(EXCEPTION) << "The result of include(node) must be one of: \"follow\", \"nofollow\", \"exclude\"";
96 }
97 }
98 return res;
99 }
100
101 // search the cnodes inside this graph only
BroadFirstSearchGraphCNodes(const std::vector<CNodePtr> & starts)102 std::vector<CNodePtr> BroadFirstSearchGraphCNodes(const std::vector<CNodePtr> &starts) {
103 std::vector<CNodePtr> todo;
104 todo.insert(todo.end(), starts.begin(), starts.end());
105 auto seen = NewSeenGeneration();
106 size_t top_idx = 0;
107 while (top_idx < todo.size()) {
108 CNodePtr top = todo[top_idx];
109 top_idx++;
110 auto inputs = top->inputs();
111 for (auto &item : inputs) {
112 if (item->seen_ == seen) {
113 continue;
114 }
115
116 if (item->isa<CNode>()) {
117 todo.push_back(item->cast<CNodePtr>());
118 }
119 item->seen_ = seen;
120 }
121 }
122 return todo;
123 }
124
125 // search the cnode match the predicate inside this graph only
BroadFirstSearchFirstOf(const std::vector<CNodePtr> & starts,const MatchFunc & match_predicate)126 CNodePtr BroadFirstSearchFirstOf(const std::vector<CNodePtr> &starts, const MatchFunc &match_predicate) {
127 std::deque<CNodePtr> todo;
128 todo.insert(todo.end(), starts.begin(), starts.end());
129 auto seen = NewSeenGeneration();
130 while (!todo.empty()) {
131 CNodePtr top = todo.front();
132 todo.pop_front();
133 if (match_predicate(top)) {
134 return top;
135 }
136 auto inputs = top->inputs();
137 for (auto &item : inputs) {
138 if (item->seen_ == seen) {
139 continue;
140 }
141
142 if (item->isa<CNode>()) {
143 todo.push_back(item->cast<CNodePtr>());
144 }
145 item->seen_ = seen;
146 }
147 }
148 return nullptr;
149 }
150
BroadFirstSearchGraphUsed(const FuncGraphPtr & root)151 std::vector<FuncGraphPtr> BroadFirstSearchGraphUsed(const FuncGraphPtr &root) {
152 std::vector<FuncGraphPtr> todo;
153 todo.push_back(root);
154 auto seen = NewSeenGeneration();
155 size_t top_idx = 0;
156 while (top_idx < todo.size()) {
157 FuncGraphPtr top = todo[top_idx];
158 top_idx++;
159 auto used = top->func_graphs_used();
160 for (auto &item : used) {
161 if (item.first->seen_ == seen) {
162 continue;
163 }
164 todo.push_back(item.first);
165 item.first->seen_ = seen;
166 }
167 }
168 return todo;
169 }
170
171 // PushSuccessors push cnode inputs to a vector as successors for topo sort.
PushSuccessors(const CNodePtr & cnode,std::vector<AnfNodePtr> * vecs)172 static void PushSuccessors(const CNodePtr &cnode, std::vector<AnfNodePtr> *vecs) {
173 auto &inputs = cnode->inputs();
174 vecs->reserve(vecs->size() + inputs.size());
175
176 // To keep sort order from left to right in default, if kAttrTopoSortRhsFirst not set.
177 auto attr_sort_rhs_first = cnode->GetAttr(kAttrTopoSortRhsFirst);
178 auto sort_rhs_first =
179 attr_sort_rhs_first != nullptr && attr_sort_rhs_first->isa<BoolImm>() && GetValue<bool>(attr_sort_rhs_first);
180 if (sort_rhs_first) {
181 vecs->insert(vecs->end(), inputs.cbegin(), inputs.cend());
182 } else {
183 vecs->insert(vecs->end(), inputs.crbegin(), inputs.crend());
184 }
185 }
186
SuccDeeper(const AnfNodePtr & node)187 std::vector<AnfNodePtr> SuccDeeper(const AnfNodePtr &node) {
188 std::vector<AnfNodePtr> vecs;
189 if (node == nullptr) {
190 return vecs;
191 }
192
193 if (IsValueNode<FuncGraph>(node)) {
194 auto graph = GetValueNode<FuncGraphPtr>(node);
195 auto ret = graph->get_return();
196 if (ret != nullptr) {
197 vecs.push_back(ret);
198 }
199 return vecs;
200 } else if (node->func_graph() != nullptr) {
201 if (node->isa<CNode>()) {
202 PushSuccessors(node->cast<CNodePtr>(), &vecs);
203 }
204 return vecs;
205 }
206
207 return vecs;
208 }
209
SuccDeeperSimple(const AnfNodePtr & node)210 std::vector<AnfNodePtr> SuccDeeperSimple(const AnfNodePtr &node) {
211 std::vector<AnfNodePtr> vecs;
212 if (node == nullptr) {
213 return vecs;
214 }
215
216 if (IsValueNode<FuncGraph>(node)) {
217 auto graph = GetValueNode<FuncGraphPtr>(node);
218 auto ret = graph->get_return();
219 if (ret != nullptr) {
220 vecs.push_back(ret);
221 }
222 return vecs;
223 } else {
224 if (node->isa<CNode>()) {
225 PushSuccessors(node->cast<CNodePtr>(), &vecs);
226 }
227 return vecs;
228 }
229 }
230
SuccIncoming(const AnfNodePtr & node)231 std::vector<AnfNodePtr> SuccIncoming(const AnfNodePtr &node) {
232 std::vector<AnfNodePtr> vecs;
233 auto cnode = dyn_cast<CNode>(node);
234 if (cnode != nullptr) {
235 PushSuccessors(cnode, &vecs);
236 }
237 return vecs;
238 }
239
SuccIncludeFV(const FuncGraphPtr & fg,const AnfNodePtr & node)240 std::vector<AnfNodePtr> SuccIncludeFV(const FuncGraphPtr &fg, const AnfNodePtr &node) {
241 std::vector<AnfNodePtr> vecs;
242 if (node == nullptr) {
243 return vecs;
244 }
245 if (node->isa<CNode>()) {
246 auto cnode = node->cast<CNodePtr>();
247 auto &inputs = cnode->inputs();
248 // Check if free variables used.
249 for (const auto &input : inputs) {
250 auto input_fg = GetValueNode<FuncGraphPtr>(input);
251 if (input_fg) {
252 for (auto &fv : input_fg->free_variables_nodes()) {
253 if (fv->func_graph() == fg && fg->nodes().contains(fv)) {
254 vecs.push_back(fv);
255 }
256 }
257 }
258 }
259 PushSuccessors(cnode, &vecs);
260 }
261 return vecs;
262 }
263
GetInputs(const AnfNodePtr & node)264 const std::vector<AnfNodePtr> &GetInputs(const AnfNodePtr &node) {
265 static std::vector<AnfNodePtr> empty_inputs;
266 auto cnode = dyn_cast<CNode>(node);
267 if (cnode != nullptr) {
268 return cnode->inputs();
269 }
270 return empty_inputs;
271 }
272
AlwaysInclude(const AnfNodePtr &)273 IncludeType AlwaysInclude(const AnfNodePtr &) { return FOLLOW; }
274
IncludeBelongGraph(const FuncGraphPtr & fg,const AnfNodePtr & node)275 IncludeType IncludeBelongGraph(const FuncGraphPtr &fg, const AnfNodePtr &node) {
276 if (node->func_graph() == fg) {
277 return FOLLOW;
278 } else {
279 return EXCLUDE;
280 }
281 }
282
FuncGraphIndex(const FuncGraphPtr & fg,const SearchFunc & search,const IncludeFunc & include)283 FuncGraphIndex::FuncGraphIndex(const FuncGraphPtr &fg, const SearchFunc &search, const IncludeFunc &include) {
284 MS_EXCEPTION_IF_NULL(fg);
285 Acquire(fg);
286
287 auto vec = search(fg->get_return(), include);
288 for (auto &node : vec) {
289 MS_EXCEPTION_IF_NULL(node);
290 Acquire(node);
291 if (node->func_graph() != nullptr) {
292 Acquire(node->func_graph());
293 }
294 }
295 }
296
GetFuncGraphs(const std::string & key)297 std::set<FuncGraphPtr> FuncGraphIndex::GetFuncGraphs(const std::string &key) {
298 std::set<FuncGraphPtr> func_graphs;
299 if (index_func_graph_.find(key) != index_func_graph_.end()) {
300 func_graphs = index_func_graph_[key];
301 }
302 return func_graphs;
303 }
304
GetNodes(const std::string & key)305 std::set<AnfNodePtr> FuncGraphIndex::GetNodes(const std::string &key) {
306 if (index_node_.find(key) != index_node_.end()) {
307 return index_node_[key];
308 }
309
310 return std::set<AnfNodePtr>();
311 }
312
GetFirstFuncGraph(const std::string & key)313 FuncGraphPtr FuncGraphIndex::GetFirstFuncGraph(const std::string &key) {
314 if (GetFuncGraphs(key).empty()) {
315 return nullptr;
316 }
317
318 auto fg = *GetFuncGraphs(key).begin();
319 return fg;
320 }
321
GetFirstNode(const std::string & key)322 AnfNodePtr FuncGraphIndex::GetFirstNode(const std::string &key) {
323 if (GetNodes(key).empty()) {
324 return nullptr;
325 }
326
327 auto node = *GetNodes(key).begin();
328 return node;
329 }
330
Acquire(const FuncGraphPtr & key)331 void FuncGraphIndex::Acquire(const FuncGraphPtr &key) {
332 std::string name = label_manage::Label(key->debug_info());
333 if (!name.empty()) {
334 (void)index_func_graph_[name].insert(key);
335 }
336 }
337
Acquire(const AnfNodePtr & key)338 void FuncGraphIndex::Acquire(const AnfNodePtr &key) {
339 std::string name = label_manage::Label(key->debug_info());
340 if (!name.empty()) {
341 (void)index_node_[name].insert(key);
342 }
343 }
344 } // namespace mindspore
345