• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * This is the C++ adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/).
3  *
4  * Copyright 2019-2020 Huawei Technologies Co., Ltd
5  *
6  * Licensed under the Apache License, Version 2.0 (the "License");
7  * you may not use this file except in compliance with the License.
8  * You may obtain a copy of the License at
9  *
10  * http://www.apache.org/licenses/LICENSE-2.0
11  *
12  * Unless required by applicable law or agreed to in writing, software
13  * distributed under the License is distributed on an "AS IS" BASIS,
14  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15  * See the License for the specific language governing permissions and
16  * limitations under the License.
17  */
18 
19 #include "ir/graph_utils.h"
20 
21 #include <unordered_map>
22 #include <unordered_set>
23 #include <utility>
24 #include <stack>
25 #include <vector>
26 #include <tuple>
27 #include <string>
28 #include <fstream>
29 #include <deque>
30 #include <set>
31 
32 #include "ir/func_graph.h"
33 #include "utils/log_adapter.h"
34 #include "utils/ms_context.h"
35 #include "mindspore/ccsrc/utils/utils.h"
36 
37 namespace mindspore {
38 // Dump the circle from the strike node `next`.
DumpSortingCircleList(const std::deque<AnfNodePtr> & todo,const AnfNodePtr & next,size_t seen)39 static size_t DumpSortingCircleList(const std::deque<AnfNodePtr> &todo, const AnfNodePtr &next, size_t seen) {
40   size_t pos = 0;
41   auto circle_node_it = std::find(todo.begin(), todo.end(), next);
42   for (; circle_node_it != todo.end(); circle_node_it++) {
43     auto circle_node = *circle_node_it;
44     if (circle_node->seen_ == seen) {
45       MS_LOG(ERROR) << "#" << pos << ": " << circle_node->DebugString();
46       pos++;
47     }
48   }
49   return pos;
50 }
51 
TopoSort(const AnfNodePtr & root,const SuccFunc & succ,const IncludeFunc & include)52 std::vector<AnfNodePtr> TopoSort(const AnfNodePtr &root, const SuccFunc &succ, const IncludeFunc &include) {
53   std::vector<AnfNodePtr> res;
54   if (root == nullptr) {
55     return res;
56   }
57   size_t seen = NewSeenGeneration();
58   std::deque<AnfNodePtr> todo;
59   todo.push_back(root);
60 
61   while (!todo.empty()) {
62     AnfNodePtr node = todo.back();
63     if (node->extra_seen_ == seen) {  // We use extra_seen_ as finish flag
64       todo.pop_back();
65       continue;
66     }
67     auto incl = include(node);
68     if (node->seen_ == seen) {  // We use seen_ as checking flag
69       todo.pop_back();
70       if (incl != EXCLUDE) {
71         res.push_back(node);
72       }
73       node->extra_seen_ = seen;
74       continue;
75     }
76     node->seen_ = seen;
77     if (incl == FOLLOW) {
78       auto succs = succ(node);
79       (void)std::copy_if(succs.begin(), succs.end(), std::back_inserter(todo), [seen, &todo](const AnfNodePtr &next) {
80         if (next == nullptr || next->extra_seen_ == seen) {
81           return false;
82         }
83         if (next->seen_ != seen) {
84           return true;
85         }
86         if (next->func_graph() != nullptr && next->func_graph()->get_return() == next) {
87           return false;
88         }
89         // To dump all nodes in a circle.
90         MS_LOG(ERROR) << "Graph cycle exists. Circle is: ";
91         auto circle_len = DumpSortingCircleList(todo, next, seen);
92         MS_LOG(EXCEPTION) << "Graph cycle exists, size: " << circle_len << ", strike node: " << next->DebugString(2);
93       });
94     } else if (incl > EXCLUDE) {  // Not NOFOLLOW or EXCLUDE
95       MS_LOG(EXCEPTION) << "The result of include(node) must be one of: \"follow\", \"nofollow\", \"exclude\"";
96     }
97   }
98   return res;
99 }
100 
101 // search the cnodes inside this graph only
BroadFirstSearchGraphCNodes(const std::vector<CNodePtr> & starts)102 std::vector<CNodePtr> BroadFirstSearchGraphCNodes(const std::vector<CNodePtr> &starts) {
103   std::vector<CNodePtr> todo;
104   todo.insert(todo.end(), starts.begin(), starts.end());
105   auto seen = NewSeenGeneration();
106   size_t top_idx = 0;
107   while (top_idx < todo.size()) {
108     CNodePtr top = todo[top_idx];
109     top_idx++;
110     auto inputs = top->inputs();
111     for (auto &item : inputs) {
112       if (item->seen_ == seen) {
113         continue;
114       }
115 
116       if (item->isa<CNode>()) {
117         todo.push_back(item->cast<CNodePtr>());
118       }
119       item->seen_ = seen;
120     }
121   }
122   return todo;
123 }
124 
125 // search the cnode match the predicate inside this graph only
BroadFirstSearchFirstOf(const std::vector<CNodePtr> & starts,const MatchFunc & match_predicate)126 CNodePtr BroadFirstSearchFirstOf(const std::vector<CNodePtr> &starts, const MatchFunc &match_predicate) {
127   std::deque<CNodePtr> todo;
128   todo.insert(todo.end(), starts.begin(), starts.end());
129   auto seen = NewSeenGeneration();
130   while (!todo.empty()) {
131     CNodePtr top = todo.front();
132     todo.pop_front();
133     if (match_predicate(top)) {
134       return top;
135     }
136     auto inputs = top->inputs();
137     for (auto &item : inputs) {
138       if (item->seen_ == seen) {
139         continue;
140       }
141 
142       if (item->isa<CNode>()) {
143         todo.push_back(item->cast<CNodePtr>());
144       }
145       item->seen_ = seen;
146     }
147   }
148   return nullptr;
149 }
150 
BroadFirstSearchGraphUsed(const FuncGraphPtr & root)151 std::vector<FuncGraphPtr> BroadFirstSearchGraphUsed(const FuncGraphPtr &root) {
152   std::vector<FuncGraphPtr> todo;
153   todo.push_back(root);
154   auto seen = NewSeenGeneration();
155   size_t top_idx = 0;
156   while (top_idx < todo.size()) {
157     FuncGraphPtr top = todo[top_idx];
158     top_idx++;
159     auto used = top->func_graphs_used();
160     for (auto &item : used) {
161       if (item.first->seen_ == seen) {
162         continue;
163       }
164       todo.push_back(item.first);
165       item.first->seen_ = seen;
166     }
167   }
168   return todo;
169 }
170 
171 // PushSuccessors push cnode inputs to a vector as successors for topo sort.
PushSuccessors(const CNodePtr & cnode,std::vector<AnfNodePtr> * vecs)172 static void PushSuccessors(const CNodePtr &cnode, std::vector<AnfNodePtr> *vecs) {
173   auto &inputs = cnode->inputs();
174   vecs->reserve(vecs->size() + inputs.size());
175 
176   // To keep sort order from left to right in default, if kAttrTopoSortRhsFirst not set.
177   auto attr_sort_rhs_first = cnode->GetAttr(kAttrTopoSortRhsFirst);
178   auto sort_rhs_first =
179     attr_sort_rhs_first != nullptr && attr_sort_rhs_first->isa<BoolImm>() && GetValue<bool>(attr_sort_rhs_first);
180   if (sort_rhs_first) {
181     vecs->insert(vecs->end(), inputs.cbegin(), inputs.cend());
182   } else {
183     vecs->insert(vecs->end(), inputs.crbegin(), inputs.crend());
184   }
185 }
186 
SuccDeeper(const AnfNodePtr & node)187 std::vector<AnfNodePtr> SuccDeeper(const AnfNodePtr &node) {
188   std::vector<AnfNodePtr> vecs;
189   if (node == nullptr) {
190     return vecs;
191   }
192 
193   if (IsValueNode<FuncGraph>(node)) {
194     auto graph = GetValueNode<FuncGraphPtr>(node);
195     auto ret = graph->get_return();
196     if (ret != nullptr) {
197       vecs.push_back(ret);
198     }
199     return vecs;
200   } else if (node->func_graph() != nullptr) {
201     if (node->isa<CNode>()) {
202       PushSuccessors(node->cast<CNodePtr>(), &vecs);
203     }
204     return vecs;
205   }
206 
207   return vecs;
208 }
209 
SuccDeeperSimple(const AnfNodePtr & node)210 std::vector<AnfNodePtr> SuccDeeperSimple(const AnfNodePtr &node) {
211   std::vector<AnfNodePtr> vecs;
212   if (node == nullptr) {
213     return vecs;
214   }
215 
216   if (IsValueNode<FuncGraph>(node)) {
217     auto graph = GetValueNode<FuncGraphPtr>(node);
218     auto ret = graph->get_return();
219     if (ret != nullptr) {
220       vecs.push_back(ret);
221     }
222     return vecs;
223   } else {
224     if (node->isa<CNode>()) {
225       PushSuccessors(node->cast<CNodePtr>(), &vecs);
226     }
227     return vecs;
228   }
229 }
230 
SuccIncoming(const AnfNodePtr & node)231 std::vector<AnfNodePtr> SuccIncoming(const AnfNodePtr &node) {
232   std::vector<AnfNodePtr> vecs;
233   auto cnode = dyn_cast<CNode>(node);
234   if (cnode != nullptr) {
235     PushSuccessors(cnode, &vecs);
236   }
237   return vecs;
238 }
239 
SuccIncludeFV(const FuncGraphPtr & fg,const AnfNodePtr & node)240 std::vector<AnfNodePtr> SuccIncludeFV(const FuncGraphPtr &fg, const AnfNodePtr &node) {
241   std::vector<AnfNodePtr> vecs;
242   if (node == nullptr) {
243     return vecs;
244   }
245   if (node->isa<CNode>()) {
246     auto cnode = node->cast<CNodePtr>();
247     auto &inputs = cnode->inputs();
248     // Check if free variables used.
249     for (const auto &input : inputs) {
250       auto input_fg = GetValueNode<FuncGraphPtr>(input);
251       if (input_fg) {
252         for (auto &fv : input_fg->free_variables_nodes()) {
253           if (fv->func_graph() == fg && fg->nodes().contains(fv)) {
254             vecs.push_back(fv);
255           }
256         }
257       }
258     }
259     PushSuccessors(cnode, &vecs);
260   }
261   return vecs;
262 }
263 
GetInputs(const AnfNodePtr & node)264 const std::vector<AnfNodePtr> &GetInputs(const AnfNodePtr &node) {
265   static std::vector<AnfNodePtr> empty_inputs;
266   auto cnode = dyn_cast<CNode>(node);
267   if (cnode != nullptr) {
268     return cnode->inputs();
269   }
270   return empty_inputs;
271 }
272 
AlwaysInclude(const AnfNodePtr &)273 IncludeType AlwaysInclude(const AnfNodePtr &) { return FOLLOW; }
274 
IncludeBelongGraph(const FuncGraphPtr & fg,const AnfNodePtr & node)275 IncludeType IncludeBelongGraph(const FuncGraphPtr &fg, const AnfNodePtr &node) {
276   if (node->func_graph() == fg) {
277     return FOLLOW;
278   } else {
279     return EXCLUDE;
280   }
281 }
282 
FuncGraphIndex(const FuncGraphPtr & fg,const SearchFunc & search,const IncludeFunc & include)283 FuncGraphIndex::FuncGraphIndex(const FuncGraphPtr &fg, const SearchFunc &search, const IncludeFunc &include) {
284   MS_EXCEPTION_IF_NULL(fg);
285   Acquire(fg);
286 
287   auto vec = search(fg->get_return(), include);
288   for (auto &node : vec) {
289     MS_EXCEPTION_IF_NULL(node);
290     Acquire(node);
291     if (node->func_graph() != nullptr) {
292       Acquire(node->func_graph());
293     }
294   }
295 }
296 
GetFuncGraphs(const std::string & key)297 std::set<FuncGraphPtr> FuncGraphIndex::GetFuncGraphs(const std::string &key) {
298   std::set<FuncGraphPtr> func_graphs;
299   if (index_func_graph_.find(key) != index_func_graph_.end()) {
300     func_graphs = index_func_graph_[key];
301   }
302   return func_graphs;
303 }
304 
GetNodes(const std::string & key)305 std::set<AnfNodePtr> FuncGraphIndex::GetNodes(const std::string &key) {
306   if (index_node_.find(key) != index_node_.end()) {
307     return index_node_[key];
308   }
309 
310   return std::set<AnfNodePtr>();
311 }
312 
GetFirstFuncGraph(const std::string & key)313 FuncGraphPtr FuncGraphIndex::GetFirstFuncGraph(const std::string &key) {
314   if (GetFuncGraphs(key).empty()) {
315     return nullptr;
316   }
317 
318   auto fg = *GetFuncGraphs(key).begin();
319   return fg;
320 }
321 
GetFirstNode(const std::string & key)322 AnfNodePtr FuncGraphIndex::GetFirstNode(const std::string &key) {
323   if (GetNodes(key).empty()) {
324     return nullptr;
325   }
326 
327   auto node = *GetNodes(key).begin();
328   return node;
329 }
330 
Acquire(const FuncGraphPtr & key)331 void FuncGraphIndex::Acquire(const FuncGraphPtr &key) {
332   std::string name = label_manage::Label(key->debug_info());
333   if (!name.empty()) {
334     (void)index_func_graph_[name].insert(key);
335   }
336 }
337 
Acquire(const AnfNodePtr & key)338 void FuncGraphIndex::Acquire(const AnfNodePtr &key) {
339   std::string name = label_manage::Label(key->debug_info());
340   if (!name.empty()) {
341     (void)index_node_[name].insert(key);
342   }
343 }
344 }  // namespace mindspore
345