1 /**
2 * Copyright 2020-2022 Huawei Technologies Co., Ltd
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 #include "ir/graph_utils.h"
18 #include "utils/hash_map.h"
19 #include "utils/hash_set.h"
20 #include "ir/visitor.h"
21 #include "ir/func_graph.h"
22 #include "utils/label.h"
23
24 namespace mindspore {
25 namespace {
26 class DeepFirstSearcher : public AnfIrVisitor {
27 public:
DeepFirstSearcher(const IncludeFunc & include,const FilterFunc & filter=nullptr)28 explicit DeepFirstSearcher(const IncludeFunc &include, const FilterFunc &filter = nullptr)
29 : include_(include), filter_(filter) {
30 constexpr size_t kVecReserve = 64;
31 res_.reserve(kVecReserve);
32 }
33 ~DeepFirstSearcher() override = default;
34
Search(const AnfNodePtr & root)35 std::vector<AnfNodePtr> Search(const AnfNodePtr &root) {
36 if (root == nullptr) {
37 return std::move(res_);
38 }
39 seen_ = NewSeenGeneration();
40 Visit(root);
41 return std::move(res_);
42 }
43
Visit(const AnfNodePtr & node)44 void Visit(const AnfNodePtr &node) override {
45 if (node == nullptr || node->seen_ == seen_) {
46 return;
47 }
48 node->seen_ = seen_;
49 auto incl = include_(node);
50 if (incl == EXCLUDE) {
51 return;
52 }
53 if (filter_ == nullptr || !filter_(node)) {
54 res_.push_back(node);
55 }
56 if (incl == FOLLOW) {
57 AnfIrVisitor::Visit(node);
58 }
59 }
60
61 private:
62 SeenNum seen_{0};
63 IncludeFunc include_;
64 FilterFunc filter_;
65 std::vector<AnfNodePtr> res_{};
66 };
67
68 class DeepScopedGraphSearcher : public DeepFirstSearcher {
69 public:
DeepScopedGraphSearcher(const IncludeFunc & include)70 explicit DeepScopedGraphSearcher(const IncludeFunc &include) : DeepFirstSearcher(include) {}
71 ~DeepScopedGraphSearcher() override = default;
72
Visit(const CNodePtr & cnode)73 void Visit(const CNodePtr &cnode) override {
74 auto fg = cnode->func_graph();
75 if (fg == nullptr) {
76 return;
77 }
78 AnfNodePtr ret = fg->return_node();
79 DeepFirstSearcher::Visit(ret);
80
81 auto &inputs = cnode->inputs();
82 for (auto iter = inputs.rbegin(); iter != inputs.rend(); ++iter) {
83 DeepFirstSearcher::Visit(*iter);
84 }
85 }
86
Visit(const ValueNodePtr & vnode)87 void Visit(const ValueNodePtr &vnode) override {
88 if (!IsValueNode<FuncGraph>(vnode)) {
89 return;
90 }
91 auto fg = GetValuePtr<FuncGraph>(vnode);
92 const auto &ret = fg->return_node();
93 DeepFirstSearcher::Visit(ret);
94 }
95
Visit(const ParameterPtr & param)96 void Visit(const ParameterPtr ¶m) override {
97 auto fg = param->func_graph();
98 if (fg == nullptr) {
99 return;
100 }
101 AnfNodePtr ret = fg->return_node();
102 DeepFirstSearcher::Visit(ret);
103 }
104 };
105
106 class DeepLinkedGraphSearcher : public DeepFirstSearcher {
107 public:
DeepLinkedGraphSearcher(const IncludeFunc & include)108 explicit DeepLinkedGraphSearcher(const IncludeFunc &include) : DeepFirstSearcher(include) {}
109 ~DeepLinkedGraphSearcher() override = default;
110
Visit(const CNodePtr & cnode)111 void Visit(const CNodePtr &cnode) override {
112 auto &inputs = cnode->inputs();
113 for (auto iter = inputs.rbegin(); iter != inputs.rend(); ++iter) {
114 DeepFirstSearcher::Visit(*iter);
115 }
116 }
117
Visit(const ValueNodePtr &)118 void Visit(const ValueNodePtr &) override {}
119 };
120 } // namespace
121
// `include` decides whether a node is recorded and whether the search expands it further;
// `filter` (when provided) returns true for nodes that must be left out of the results.
DeepScopedGraphSearch(const AnfNodePtr & root,const IncludeFunc & include)123 std::vector<AnfNodePtr> DeepScopedGraphSearch(const AnfNodePtr &root, const IncludeFunc &include) {
124 return DeepScopedGraphSearcher(include).Search(root);
125 }
126
DeepScopedGraphSearchWithFilter(const AnfNodePtr & root,const IncludeFunc & include,const FilterFunc & filter)127 std::vector<AnfNodePtr> DeepScopedGraphSearchWithFilter(const AnfNodePtr &root, const IncludeFunc &include,
128 const FilterFunc &filter) {
129 return DeepFirstSearcher(include, filter).Search(root);
130 }
131
DeepLinkedGraphSearch(const AnfNodePtr & root,const IncludeFunc & include)132 std::vector<AnfNodePtr> DeepLinkedGraphSearch(const AnfNodePtr &root, const IncludeFunc &include) {
133 return DeepLinkedGraphSearcher(include).Search(root);
134 }
135 } // namespace mindspore
136