• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2020-2022 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #include "ir/graph_utils.h"
18 #include "utils/hash_map.h"
19 #include "utils/hash_set.h"
20 #include "ir/visitor.h"
21 #include "ir/func_graph.h"
22 #include "utils/label.h"
23 
24 namespace mindspore {
25 namespace {
26 class DeepFirstSearcher : public AnfIrVisitor {
27  public:
DeepFirstSearcher(const IncludeFunc & include,const FilterFunc & filter=nullptr)28   explicit DeepFirstSearcher(const IncludeFunc &include, const FilterFunc &filter = nullptr)
29       : include_(include), filter_(filter) {
30     constexpr size_t kVecReserve = 64;
31     res_.reserve(kVecReserve);
32   }
33   ~DeepFirstSearcher() override = default;
34 
Search(const AnfNodePtr & root)35   std::vector<AnfNodePtr> Search(const AnfNodePtr &root) {
36     if (root == nullptr) {
37       return std::move(res_);
38     }
39     seen_ = NewSeenGeneration();
40     Visit(root);
41     return std::move(res_);
42   }
43 
Visit(const AnfNodePtr & node)44   void Visit(const AnfNodePtr &node) override {
45     if (node == nullptr || node->seen_ == seen_) {
46       return;
47     }
48     node->seen_ = seen_;
49     auto incl = include_(node);
50     if (incl == EXCLUDE) {
51       return;
52     }
53     if (filter_ == nullptr || !filter_(node)) {
54       res_.push_back(node);
55     }
56     if (incl == FOLLOW) {
57       AnfIrVisitor::Visit(node);
58     }
59   }
60 
61  private:
62   SeenNum seen_{0};
63   IncludeFunc include_;
64   FilterFunc filter_;
65   std::vector<AnfNodePtr> res_{};
66 };
67 
68 class DeepScopedGraphSearcher : public DeepFirstSearcher {
69  public:
DeepScopedGraphSearcher(const IncludeFunc & include)70   explicit DeepScopedGraphSearcher(const IncludeFunc &include) : DeepFirstSearcher(include) {}
71   ~DeepScopedGraphSearcher() override = default;
72 
Visit(const CNodePtr & cnode)73   void Visit(const CNodePtr &cnode) override {
74     auto fg = cnode->func_graph();
75     if (fg == nullptr) {
76       return;
77     }
78     AnfNodePtr ret = fg->return_node();
79     DeepFirstSearcher::Visit(ret);
80 
81     auto &inputs = cnode->inputs();
82     for (auto iter = inputs.rbegin(); iter != inputs.rend(); ++iter) {
83       DeepFirstSearcher::Visit(*iter);
84     }
85   }
86 
Visit(const ValueNodePtr & vnode)87   void Visit(const ValueNodePtr &vnode) override {
88     if (!IsValueNode<FuncGraph>(vnode)) {
89       return;
90     }
91     auto fg = GetValuePtr<FuncGraph>(vnode);
92     const auto &ret = fg->return_node();
93     DeepFirstSearcher::Visit(ret);
94   }
95 
Visit(const ParameterPtr & param)96   void Visit(const ParameterPtr &param) override {
97     auto fg = param->func_graph();
98     if (fg == nullptr) {
99       return;
100     }
101     AnfNodePtr ret = fg->return_node();
102     DeepFirstSearcher::Visit(ret);
103   }
104 };
105 
106 class DeepLinkedGraphSearcher : public DeepFirstSearcher {
107  public:
DeepLinkedGraphSearcher(const IncludeFunc & include)108   explicit DeepLinkedGraphSearcher(const IncludeFunc &include) : DeepFirstSearcher(include) {}
109   ~DeepLinkedGraphSearcher() override = default;
110 
Visit(const CNodePtr & cnode)111   void Visit(const CNodePtr &cnode) override {
112     auto &inputs = cnode->inputs();
113     for (auto iter = inputs.rbegin(); iter != inputs.rend(); ++iter) {
114       DeepFirstSearcher::Visit(*iter);
115     }
116   }
117 
Visit(const ValueNodePtr &)118   void Visit(const ValueNodePtr &) override {}
119 };
120 }  // namespace
121 
122 // include for if expand the node the search, filter for if put the node to results.
DeepScopedGraphSearch(const AnfNodePtr & root,const IncludeFunc & include)123 std::vector<AnfNodePtr> DeepScopedGraphSearch(const AnfNodePtr &root, const IncludeFunc &include) {
124   return DeepScopedGraphSearcher(include).Search(root);
125 }
126 
DeepScopedGraphSearchWithFilter(const AnfNodePtr & root,const IncludeFunc & include,const FilterFunc & filter)127 std::vector<AnfNodePtr> DeepScopedGraphSearchWithFilter(const AnfNodePtr &root, const IncludeFunc &include,
128                                                         const FilterFunc &filter) {
129   return DeepFirstSearcher(include, filter).Search(root);
130 }
131 
DeepLinkedGraphSearch(const AnfNodePtr & root,const IncludeFunc & include)132 std::vector<AnfNodePtr> DeepLinkedGraphSearch(const AnfNodePtr &root, const IncludeFunc &include) {
133   return DeepLinkedGraphSearcher(include).Search(root);
134 }
135 }  // namespace mindspore
136