1 /** 2 * This is the C++ adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/). 3 * 4 * Copyright 2019-2020 Huawei Technologies Co., Ltd 5 * 6 * Licensed under the Apache License, Version 2.0 (the "License"); 7 * you may not use this file except in compliance with the License. 8 * You may obtain a copy of the License at 9 * 10 * http://www.apache.org/licenses/LICENSE-2.0 11 * 12 * Unless required by applicable law or agreed to in writing, software 13 * distributed under the License is distributed on an "AS IS" BASIS, 14 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 15 * See the License for the specific language governing permissions and 16 * limitations under the License. 17 */ 18 19 #ifndef MINDSPORE_CORE_IR_GRAPH_UTILS_H_ 20 #define MINDSPORE_CORE_IR_GRAPH_UTILS_H_ 21 22 #include <unordered_map> 23 #include <unordered_set> 24 #include <utility> 25 #include <memory> 26 #include <vector> 27 #include <map> 28 #include <set> 29 #include <string> 30 31 #include "ir/anf.h" 32 #include "ir/primitive.h" 33 #include "ir/scalar.h" 34 #include "ir/tensor.h" 35 #include "utils/label.h" 36 37 namespace mindspore { 38 enum IncludeType { FOLLOW, NOFOLLOW, EXCLUDE }; 39 40 using IncludeFunc = std::function<IncludeType(const AnfNodePtr &)>; 41 using FilterFunc = std::function<bool(const AnfNodePtr &)>; 42 using SuccFunc = std::function<std::vector<AnfNodePtr>(AnfNodePtr)>; 43 using SearchFunc = std::function<std::vector<AnfNodePtr>(const AnfNodePtr &, const IncludeFunc &)>; 44 using MatchFunc = std::function<bool(const CNodePtr &)>; 45 46 std::vector<AnfNodePtr> DeepScopedGraphSearch(const AnfNodePtr &root, const IncludeFunc &include); 47 std::vector<AnfNodePtr> DeepUsedGraphSearch(const AnfNodePtr &root, const IncludeFunc &include); 48 std::vector<AnfNodePtr> DeepLinkedGraphSearch(const AnfNodePtr &root, const IncludeFunc &include); 49 50 std::vector<AnfNodePtr> SuccDeeper(const AnfNodePtr &node); 51 std::vector<AnfNodePtr> SuccDeeperSimple(const AnfNodePtr &node); 52 std::vector<AnfNodePtr> SuccIncoming(const AnfNodePtr &node); 53 std::vector<AnfNodePtr> SuccIncludeFV(const FuncGraphPtr &fg, const AnfNodePtr &node); 54 55 const std::vector<AnfNodePtr> &GetInputs(const AnfNodePtr &node); 56 57 IncludeType AlwaysInclude(const AnfNodePtr &node); 58 IncludeType IncludeBelongGraph(const FuncGraphPtr &fg, const AnfNodePtr &node); 59 60 std::vector<AnfNodePtr> DeepScopedGraphSearch(const AnfNodePtr &root, const IncludeFunc &include = AlwaysInclude); 61 std::vector<AnfNodePtr> DeepUsedGraphSearch(const AnfNodePtr &root, const IncludeFunc &include = AlwaysInclude); 62 std::vector<AnfNodePtr> DeepLinkedGraphSearch(const AnfNodePtr &root, const IncludeFunc &include = AlwaysInclude); 63 64 std::vector<AnfNodePtr> DeepScopedGraphSearchWithFilter(const AnfNodePtr &root, const IncludeFunc &include, 65 const FilterFunc &filter); 66 67 class FuncGraphManager; 68 using FuncGraphManagerPtr = std::shared_ptr<FuncGraphManager>; 69 std::vector<AnfNodePtr> DeepUsersSearch(const AnfNodePtr &root, const IncludeFunc &include, 70 const FuncGraphManagerPtr &mng); 71 std::vector<AnfNodePtr> TopoSort(const AnfNodePtr &root, const SuccFunc &succ = SuccIncoming, 72 const IncludeFunc &include = AlwaysInclude); 73 74 std::vector<CNodePtr> BroadFirstSearchGraphCNodes(const std::vector<CNodePtr> &starts); 75 std::vector<FuncGraphPtr> BroadFirstSearchGraphUsed(const FuncGraphPtr &root); 76 77 CNodePtr BroadFirstSearchFirstOf(const std::vector<CNodePtr> &starts, const MatchFunc &match_predicate); 78 79 class FuncGraphIndex { 80 public: 81 explicit FuncGraphIndex(const FuncGraphPtr &fg, const SearchFunc &search = DeepScopedGraphSearch, 82 const IncludeFunc &include = AlwaysInclude); 83 FuncGraphIndex(const FuncGraphIndex &) = delete; 84 FuncGraphIndex &operator=(const FuncGraphIndex &) = delete; 85 ~FuncGraphIndex()86 virtual ~FuncGraphIndex() {} 87 88 std::set<FuncGraphPtr> GetFuncGraphs(const std::string &key); 89 std::set<AnfNodePtr> GetNodes(const std::string &key); 90 FuncGraphPtr GetFirstFuncGraph(const std::string &key); 91 AnfNodePtr GetFirstNode(const std::string &key); 92 93 private: 94 void Acquire(const FuncGraphPtr &key); 95 void Acquire(const AnfNodePtr &key); 96 97 std::map<std::string, std::set<FuncGraphPtr>> index_func_graph_; 98 std::map<std::string, std::set<AnfNodePtr>> index_node_; 99 }; 100 } // namespace mindspore 101 102 #endif // MINDSPORE_CORE_IR_GRAPH_UTILS_H_ 103