1 /**
2 * Copyright 2020 Huawei Technologies Co., Ltd
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 #include "ir/graph_utils.h"
18
19 #include <unordered_map>
20 #include <unordered_set>
21 #include <utility>
22 #include <stack>
23 #include <vector>
24 #include <list>
25 #include <string>
26 #include <fstream>
27
28 #include "ir/visitor.h"
29 #include "ir/manager.h"
30 #include "ir/func_graph.h"
31 #include "utils/label.h"
32 #include "utils/log_adapter.h"
33 #include "utils/ms_utils.h"
34
35 namespace mindspore {
36 namespace {
37 class DeepFirstSearcher : public AnfIrVisitor {
38 public:
DeepFirstSearcher(const IncludeFunc & include,const FilterFunc & filter=nullptr)39 explicit DeepFirstSearcher(const IncludeFunc &include, const FilterFunc &filter = nullptr)
40 : include_(include), filter_(filter) {}
41 ~DeepFirstSearcher() override = default;
42
Search(const AnfNodePtr & root)43 std::vector<AnfNodePtr> Search(const AnfNodePtr &root) {
44 if (root == nullptr) {
45 return std::move(res_);
46 }
47 seen_ = NewSeenGeneration();
48 Visit(root);
49 return std::move(res_);
50 }
51
Visit(const AnfNodePtr & node)52 void Visit(const AnfNodePtr &node) override {
53 MS_EXCEPTION_IF_NULL(node);
54 if (node->seen_ == seen_) {
55 return;
56 }
57
58 node->seen_ = seen_;
59
60 auto incl = include_(node);
61 if (incl == EXCLUDE) {
62 return;
63 }
64 if (filter_ == nullptr || !filter_(node)) {
65 res_.push_back(node);
66 }
67 if (incl == FOLLOW) {
68 AnfIrVisitor::Visit(node);
69 }
70 }
71
72 private:
73 size_t seen_{0};
74 IncludeFunc include_;
75 FilterFunc filter_;
76 std::vector<AnfNodePtr> res_{};
77 };
78
79 class DeepScopedGraphSearcher : public DeepFirstSearcher {
80 public:
DeepScopedGraphSearcher(const IncludeFunc & include)81 explicit DeepScopedGraphSearcher(const IncludeFunc &include) : DeepFirstSearcher(include) {}
82 ~DeepScopedGraphSearcher() override = default;
83
Visit(const CNodePtr & cnode)84 void Visit(const CNodePtr &cnode) override {
85 if (cnode->func_graph() == nullptr) {
86 return;
87 }
88
89 AnfNodePtr ret = cnode->func_graph()->get_return();
90 if (ret != nullptr) {
91 DeepFirstSearcher::Visit(ret);
92 }
93
94 auto &inputs = cnode->inputs();
95 for (auto iter = inputs.rbegin(); iter != inputs.rend(); ++iter) {
96 DeepFirstSearcher::Visit(*iter);
97 }
98 }
99
Visit(const ValueNodePtr & vnode)100 void Visit(const ValueNodePtr &vnode) override {
101 if (!IsValueNode<FuncGraph>(vnode)) {
102 return;
103 }
104
105 auto graph = GetValueNode<FuncGraphPtr>(vnode);
106 AnfNodePtr ret = graph->get_return();
107 if (ret != nullptr) {
108 DeepFirstSearcher::Visit(ret);
109 }
110 }
111
Visit(const ParameterPtr & param)112 void Visit(const ParameterPtr ¶m) override {
113 if (param->func_graph() == nullptr) {
114 return;
115 }
116
117 AnfNodePtr ret = param->func_graph()->get_return();
118 if (ret != nullptr) {
119 DeepFirstSearcher::Visit(ret);
120 }
121 }
122 };
123
124 class DeepUsedGraphSearcher : public DeepFirstSearcher {
125 public:
DeepUsedGraphSearcher(const IncludeFunc & include)126 explicit DeepUsedGraphSearcher(const IncludeFunc &include) : DeepFirstSearcher(include) {}
127 ~DeepUsedGraphSearcher() override = default;
128
Visit(const CNodePtr & cnode)129 void Visit(const CNodePtr &cnode) override {
130 auto &inputs = cnode->inputs();
131 for (auto iter = inputs.rbegin(); iter != inputs.rend(); ++iter) {
132 DeepFirstSearcher::Visit(*iter);
133 }
134 }
135
Visit(const ValueNodePtr & vnode)136 void Visit(const ValueNodePtr &vnode) override {
137 if (!IsValueNode<FuncGraph>(vnode)) {
138 return;
139 }
140
141 auto graph = GetValueNode<FuncGraphPtr>(vnode);
142 AnfNodePtr ret = graph->get_return();
143 if (ret != nullptr) {
144 DeepFirstSearcher::Visit(ret);
145 }
146 }
147 };
148
149 class DeepLinkedGraphSearcher : public DeepFirstSearcher {
150 public:
DeepLinkedGraphSearcher(const IncludeFunc & include)151 explicit DeepLinkedGraphSearcher(const IncludeFunc &include) : DeepFirstSearcher(include) {}
152 ~DeepLinkedGraphSearcher() override = default;
153
Visit(const CNodePtr & cnode)154 void Visit(const CNodePtr &cnode) override {
155 auto &inputs = cnode->inputs();
156 for (auto iter = inputs.rbegin(); iter != inputs.rend(); ++iter) {
157 DeepFirstSearcher::Visit(*iter);
158 }
159 }
160
Visit(const ValueNodePtr &)161 void Visit(const ValueNodePtr &) override {}
162 };
163
164 class DeepUsersSearcher : public DeepFirstSearcher {
165 public:
DeepUsersSearcher(const IncludeFunc & include,const FuncGraphManagerPtr & mng)166 explicit DeepUsersSearcher(const IncludeFunc &include, const FuncGraphManagerPtr &mng)
167 : DeepFirstSearcher(include), mng_(mng) {}
168 ~DeepUsersSearcher() override = default;
169
Visit(const CNodePtr & cnode)170 void Visit(const CNodePtr &cnode) override {
171 auto &users = mng_->node_users()[cnode];
172 for (auto iter = users.begin(); iter != users.end(); ++iter) {
173 DeepFirstSearcher::Visit(iter->first);
174 }
175 }
Visit(const ValueNodePtr &)176 void Visit(const ValueNodePtr &) override {}
177
178 private:
179 FuncGraphManagerPtr mng_;
180 };
181 } // namespace
182
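// In the Deep*Search functions below, `include` decides for each reached node whether it is skipped
// entirely (EXCLUDE), kept and expanded further (FOLLOW), or kept without being expanded (any other value).
// `filter`, where accepted, removes a kept node from the results when it returns true.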
DeepScopedGraphSearch(const AnfNodePtr & root,const IncludeFunc & include)184 std::vector<AnfNodePtr> DeepScopedGraphSearch(const AnfNodePtr &root, const IncludeFunc &include) {
185 return DeepScopedGraphSearcher(include).Search(root);
186 }
187
DeepScopedGraphSearchWithFilter(const AnfNodePtr & root,const IncludeFunc & include,const FilterFunc & filter)188 std::vector<AnfNodePtr> DeepScopedGraphSearchWithFilter(const AnfNodePtr &root, const IncludeFunc &include,
189 const FilterFunc &filter) {
190 return DeepFirstSearcher(include, filter).Search(root);
191 }
192
DeepUsedGraphSearch(const AnfNodePtr & root,const IncludeFunc & include)193 std::vector<AnfNodePtr> DeepUsedGraphSearch(const AnfNodePtr &root, const IncludeFunc &include) {
194 return DeepUsedGraphSearcher(include).Search(root);
195 }
196
DeepLinkedGraphSearch(const AnfNodePtr & root,const IncludeFunc & include)197 std::vector<AnfNodePtr> DeepLinkedGraphSearch(const AnfNodePtr &root, const IncludeFunc &include) {
198 return DeepLinkedGraphSearcher(include).Search(root);
199 }
200
DeepUsersSearch(const AnfNodePtr & root,const IncludeFunc & include,const FuncGraphManagerPtr & mng)201 std::vector<AnfNodePtr> DeepUsersSearch(const AnfNodePtr &root, const IncludeFunc &include,
202 const FuncGraphManagerPtr &mng) {
203 return DeepUsersSearcher(include, mng).Search(root);
204 }
205 } // namespace mindspore
206