• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2020 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #include "ir/graph_utils.h"
18 
19 #include <unordered_map>
20 #include <unordered_set>
21 #include <utility>
22 #include <stack>
23 #include <vector>
24 #include <list>
25 #include <string>
26 #include <fstream>
27 
28 #include "ir/visitor.h"
29 #include "ir/manager.h"
30 #include "ir/func_graph.h"
31 #include "utils/label.h"
32 #include "utils/log_adapter.h"
33 #include "utils/ms_utils.h"
34 
35 namespace mindspore {
36 namespace {
37 class DeepFirstSearcher : public AnfIrVisitor {
38  public:
DeepFirstSearcher(const IncludeFunc & include,const FilterFunc & filter=nullptr)39   explicit DeepFirstSearcher(const IncludeFunc &include, const FilterFunc &filter = nullptr)
40       : include_(include), filter_(filter) {}
41   ~DeepFirstSearcher() override = default;
42 
Search(const AnfNodePtr & root)43   std::vector<AnfNodePtr> Search(const AnfNodePtr &root) {
44     if (root == nullptr) {
45       return std::move(res_);
46     }
47     seen_ = NewSeenGeneration();
48     Visit(root);
49     return std::move(res_);
50   }
51 
Visit(const AnfNodePtr & node)52   void Visit(const AnfNodePtr &node) override {
53     MS_EXCEPTION_IF_NULL(node);
54     if (node->seen_ == seen_) {
55       return;
56     }
57 
58     node->seen_ = seen_;
59 
60     auto incl = include_(node);
61     if (incl == EXCLUDE) {
62       return;
63     }
64     if (filter_ == nullptr || !filter_(node)) {
65       res_.push_back(node);
66     }
67     if (incl == FOLLOW) {
68       AnfIrVisitor::Visit(node);
69     }
70   }
71 
72  private:
73   size_t seen_{0};
74   IncludeFunc include_;
75   FilterFunc filter_;
76   std::vector<AnfNodePtr> res_{};
77 };
78 
79 class DeepScopedGraphSearcher : public DeepFirstSearcher {
80  public:
DeepScopedGraphSearcher(const IncludeFunc & include)81   explicit DeepScopedGraphSearcher(const IncludeFunc &include) : DeepFirstSearcher(include) {}
82   ~DeepScopedGraphSearcher() override = default;
83 
Visit(const CNodePtr & cnode)84   void Visit(const CNodePtr &cnode) override {
85     if (cnode->func_graph() == nullptr) {
86       return;
87     }
88 
89     AnfNodePtr ret = cnode->func_graph()->get_return();
90     if (ret != nullptr) {
91       DeepFirstSearcher::Visit(ret);
92     }
93 
94     auto &inputs = cnode->inputs();
95     for (auto iter = inputs.rbegin(); iter != inputs.rend(); ++iter) {
96       DeepFirstSearcher::Visit(*iter);
97     }
98   }
99 
Visit(const ValueNodePtr & vnode)100   void Visit(const ValueNodePtr &vnode) override {
101     if (!IsValueNode<FuncGraph>(vnode)) {
102       return;
103     }
104 
105     auto graph = GetValueNode<FuncGraphPtr>(vnode);
106     AnfNodePtr ret = graph->get_return();
107     if (ret != nullptr) {
108       DeepFirstSearcher::Visit(ret);
109     }
110   }
111 
Visit(const ParameterPtr & param)112   void Visit(const ParameterPtr &param) override {
113     if (param->func_graph() == nullptr) {
114       return;
115     }
116 
117     AnfNodePtr ret = param->func_graph()->get_return();
118     if (ret != nullptr) {
119       DeepFirstSearcher::Visit(ret);
120     }
121   }
122 };
123 
124 class DeepUsedGraphSearcher : public DeepFirstSearcher {
125  public:
DeepUsedGraphSearcher(const IncludeFunc & include)126   explicit DeepUsedGraphSearcher(const IncludeFunc &include) : DeepFirstSearcher(include) {}
127   ~DeepUsedGraphSearcher() override = default;
128 
Visit(const CNodePtr & cnode)129   void Visit(const CNodePtr &cnode) override {
130     auto &inputs = cnode->inputs();
131     for (auto iter = inputs.rbegin(); iter != inputs.rend(); ++iter) {
132       DeepFirstSearcher::Visit(*iter);
133     }
134   }
135 
Visit(const ValueNodePtr & vnode)136   void Visit(const ValueNodePtr &vnode) override {
137     if (!IsValueNode<FuncGraph>(vnode)) {
138       return;
139     }
140 
141     auto graph = GetValueNode<FuncGraphPtr>(vnode);
142     AnfNodePtr ret = graph->get_return();
143     if (ret != nullptr) {
144       DeepFirstSearcher::Visit(ret);
145     }
146   }
147 };
148 
149 class DeepLinkedGraphSearcher : public DeepFirstSearcher {
150  public:
DeepLinkedGraphSearcher(const IncludeFunc & include)151   explicit DeepLinkedGraphSearcher(const IncludeFunc &include) : DeepFirstSearcher(include) {}
152   ~DeepLinkedGraphSearcher() override = default;
153 
Visit(const CNodePtr & cnode)154   void Visit(const CNodePtr &cnode) override {
155     auto &inputs = cnode->inputs();
156     for (auto iter = inputs.rbegin(); iter != inputs.rend(); ++iter) {
157       DeepFirstSearcher::Visit(*iter);
158     }
159   }
160 
Visit(const ValueNodePtr &)161   void Visit(const ValueNodePtr &) override {}
162 };
163 
164 class DeepUsersSearcher : public DeepFirstSearcher {
165  public:
DeepUsersSearcher(const IncludeFunc & include,const FuncGraphManagerPtr & mng)166   explicit DeepUsersSearcher(const IncludeFunc &include, const FuncGraphManagerPtr &mng)
167       : DeepFirstSearcher(include), mng_(mng) {}
168   ~DeepUsersSearcher() override = default;
169 
Visit(const CNodePtr & cnode)170   void Visit(const CNodePtr &cnode) override {
171     auto &users = mng_->node_users()[cnode];
172     for (auto iter = users.begin(); iter != users.end(); ++iter) {
173       DeepFirstSearcher::Visit(iter->first);
174     }
175   }
Visit(const ValueNodePtr &)176   void Visit(const ValueNodePtr &) override {}
177 
178  private:
179   FuncGraphManagerPtr mng_;
180 };
181 }  // namespace
182 
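// `include` decides, per node, whether it is excluded entirely (EXCLUDE), kept and expanded (FOLLOW), or
// kept without expanding; `filter`, when given, drops a kept node from the results if it returns true.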
DeepScopedGraphSearch(const AnfNodePtr & root,const IncludeFunc & include)184 std::vector<AnfNodePtr> DeepScopedGraphSearch(const AnfNodePtr &root, const IncludeFunc &include) {
185   return DeepScopedGraphSearcher(include).Search(root);
186 }
187 
DeepScopedGraphSearchWithFilter(const AnfNodePtr & root,const IncludeFunc & include,const FilterFunc & filter)188 std::vector<AnfNodePtr> DeepScopedGraphSearchWithFilter(const AnfNodePtr &root, const IncludeFunc &include,
189                                                         const FilterFunc &filter) {
190   return DeepFirstSearcher(include, filter).Search(root);
191 }
192 
DeepUsedGraphSearch(const AnfNodePtr & root,const IncludeFunc & include)193 std::vector<AnfNodePtr> DeepUsedGraphSearch(const AnfNodePtr &root, const IncludeFunc &include) {
194   return DeepUsedGraphSearcher(include).Search(root);
195 }
196 
DeepLinkedGraphSearch(const AnfNodePtr & root,const IncludeFunc & include)197 std::vector<AnfNodePtr> DeepLinkedGraphSearch(const AnfNodePtr &root, const IncludeFunc &include) {
198   return DeepLinkedGraphSearcher(include).Search(root);
199 }
200 
DeepUsersSearch(const AnfNodePtr & root,const IncludeFunc & include,const FuncGraphManagerPtr & mng)201 std::vector<AnfNodePtr> DeepUsersSearch(const AnfNodePtr &root, const IncludeFunc &include,
202                                         const FuncGraphManagerPtr &mng) {
203   return DeepUsersSearcher(include, mng).Search(root);
204 }
205 }  // namespace mindspore
206