• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * This is the C++ adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/).
3  *
4  * Copyright 2019-2020 Huawei Technologies Co., Ltd
5  *
6  * Licensed under the Apache License, Version 2.0 (the "License");
7  * you may not use this file except in compliance with the License.
8  * You may obtain a copy of the License at
9  *
10  * http://www.apache.org/licenses/LICENSE-2.0
11  *
12  * Unless required by applicable law or agreed to in writing, software
13  * distributed under the License is distributed on an "AS IS" BASIS,
14  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15  * See the License for the specific language governing permissions and
16  * limitations under the License.
17  */
18 
19 #ifndef MINDSPORE_CORE_IR_GRAPH_UTILS_H_
20 #define MINDSPORE_CORE_IR_GRAPH_UTILS_H_
21 
#include <functional>
#include <map>
#include <memory>
#include <set>
#include <string>
#include <unordered_map>
#include <unordered_set>
#include <utility>
#include <vector>
30 
31 #include "ir/anf.h"
32 #include "ir/primitive.h"
33 #include "ir/scalar.h"
34 #include "ir/tensor.h"
35 #include "utils/label.h"
36 
37 namespace mindspore {
38 enum IncludeType { FOLLOW, NOFOLLOW, EXCLUDE };
39 
40 using IncludeFunc = std::function<IncludeType(const AnfNodePtr &)>;
41 using FilterFunc = std::function<bool(const AnfNodePtr &)>;
42 using SuccFunc = std::function<std::vector<AnfNodePtr>(AnfNodePtr)>;
43 using SearchFunc = std::function<std::vector<AnfNodePtr>(const AnfNodePtr &, const IncludeFunc &)>;
44 using MatchFunc = std::function<bool(const CNodePtr &)>;
45 
46 std::vector<AnfNodePtr> DeepScopedGraphSearch(const AnfNodePtr &root, const IncludeFunc &include);
47 std::vector<AnfNodePtr> DeepUsedGraphSearch(const AnfNodePtr &root, const IncludeFunc &include);
48 std::vector<AnfNodePtr> DeepLinkedGraphSearch(const AnfNodePtr &root, const IncludeFunc &include);
49 
50 std::vector<AnfNodePtr> SuccDeeper(const AnfNodePtr &node);
51 std::vector<AnfNodePtr> SuccDeeperSimple(const AnfNodePtr &node);
52 std::vector<AnfNodePtr> SuccIncoming(const AnfNodePtr &node);
53 std::vector<AnfNodePtr> SuccIncludeFV(const FuncGraphPtr &fg, const AnfNodePtr &node);
54 
55 const std::vector<AnfNodePtr> &GetInputs(const AnfNodePtr &node);
56 
57 IncludeType AlwaysInclude(const AnfNodePtr &node);
58 IncludeType IncludeBelongGraph(const FuncGraphPtr &fg, const AnfNodePtr &node);
59 
60 std::vector<AnfNodePtr> DeepScopedGraphSearch(const AnfNodePtr &root, const IncludeFunc &include = AlwaysInclude);
61 std::vector<AnfNodePtr> DeepUsedGraphSearch(const AnfNodePtr &root, const IncludeFunc &include = AlwaysInclude);
62 std::vector<AnfNodePtr> DeepLinkedGraphSearch(const AnfNodePtr &root, const IncludeFunc &include = AlwaysInclude);
63 
64 std::vector<AnfNodePtr> DeepScopedGraphSearchWithFilter(const AnfNodePtr &root, const IncludeFunc &include,
65                                                         const FilterFunc &filter);
66 
67 class FuncGraphManager;
68 using FuncGraphManagerPtr = std::shared_ptr<FuncGraphManager>;
69 std::vector<AnfNodePtr> DeepUsersSearch(const AnfNodePtr &root, const IncludeFunc &include,
70                                         const FuncGraphManagerPtr &mng);
71 std::vector<AnfNodePtr> TopoSort(const AnfNodePtr &root, const SuccFunc &succ = SuccIncoming,
72                                  const IncludeFunc &include = AlwaysInclude);
73 
74 std::vector<CNodePtr> BroadFirstSearchGraphCNodes(const std::vector<CNodePtr> &starts);
75 std::vector<FuncGraphPtr> BroadFirstSearchGraphUsed(const FuncGraphPtr &root);
76 
77 CNodePtr BroadFirstSearchFirstOf(const std::vector<CNodePtr> &starts, const MatchFunc &match_predicate);
78 
79 class FuncGraphIndex {
80  public:
81   explicit FuncGraphIndex(const FuncGraphPtr &fg, const SearchFunc &search = DeepScopedGraphSearch,
82                           const IncludeFunc &include = AlwaysInclude);
83   FuncGraphIndex(const FuncGraphIndex &) = delete;
84   FuncGraphIndex &operator=(const FuncGraphIndex &) = delete;
85 
~FuncGraphIndex()86   virtual ~FuncGraphIndex() {}
87 
88   std::set<FuncGraphPtr> GetFuncGraphs(const std::string &key);
89   std::set<AnfNodePtr> GetNodes(const std::string &key);
90   FuncGraphPtr GetFirstFuncGraph(const std::string &key);
91   AnfNodePtr GetFirstNode(const std::string &key);
92 
93  private:
94   void Acquire(const FuncGraphPtr &key);
95   void Acquire(const AnfNodePtr &key);
96 
97   std::map<std::string, std::set<FuncGraphPtr>> index_func_graph_;
98   std::map<std::string, std::set<AnfNodePtr>> index_node_;
99 };
100 }  // namespace mindspore
101 
102 #endif  // MINDSPORE_CORE_IR_GRAPH_UTILS_H_
103