1 /** 2 * Copyright 2021 Huawei Technologies Co., Ltd 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 17 #ifndef MINDSPORE_LITE_TOOLS_COMMON_FUNC_GRAPH_SUBGRAPH_H_ 18 #define MINDSPORE_LITE_TOOLS_COMMON_FUNC_GRAPH_SUBGRAPH_H_ 19 20 #include <memory> 21 #include <string> 22 #include <vector> 23 #include <map> 24 #include <set> 25 #include <utility> 26 #include "src/common/log_adapter.h" 27 #include "include/errorcode.h" 28 #include "ir/anf.h" 29 #include "ir/func_graph.h" 30 31 namespace mindspore::lite { 32 class SubGraph; 33 using SubGraphPtr = std::shared_ptr<SubGraph>; 34 class SubGraph { 35 public: 36 explicit SubGraph(FuncGraphPtr belong_anf, std::string graph_name = "") belong_anf_(std::move (belong_anf))37 : belong_anf_(std::move(belong_anf)), name_(std::move(graph_name)) {} 38 39 virtual ~SubGraph() = default; 40 int Init(const std::set<CNodePtr> &head_nodes = {}); 41 int Reset(const std::set<CNodePtr> &nodes, const std::set<CNodePtr> &head_nodes = {}); 42 43 bool MergeSubGraph(const SubGraphPtr &subgraph); 44 45 std::set<CNodePtr> GetNodes() const; 46 std::set<CNodePtr> GetInCNodes() const; 47 std::set<CNodePtr> GetOutCNodes() const; 48 49 int ApplySubGraph(); 50 51 private: 52 std::set<CNodePtr> GetInputCNodes() const; 53 std::set<CNodePtr> GetOutputCNodes() const; 54 // init subgraph methods 55 int InitSubGraphNode(const std::set<CNodePtr> &head_nodes); 56 int InitSubGraphInNode(); 57 int InitSubGraphOutNode(); 58 // merge subgraph methods 59 std::set<CNodePtr> FindCommonOutputs(const SubGraphPtr &subgraph) const; 60 bool IfDependOnSameNode(const SubGraphPtr &subgraph) const; 61 // apply subgraph methods 62 SubGraphPtr FindBeforeSubGraphInBelongAnf() const; 63 SubGraphPtr FindAfterSubGraphInBelongAnf() const; 64 int CreateParameterForPartialSubGraph(const FuncGraphPtr &sub_graph, std::vector<AnfNodePtr> *partial_inputs, 65 std::map<AnfNodePtr, AnfNodePtr> *partial_inputs_and_subgraph_input_map); 66 int CreateCNodeForPartialSubGraph(const FuncGraphPtr &sub_graph, 67 const std::map<AnfNodePtr, AnfNodePtr> &partial_inputs_and_subgraph_input_map); 68 int CreatePartialInBelongAnf(); 69 static int SetFuncGraphOutput(const FuncGraphPtr &graph, const std::set<CNodePtr> &outputs); 70 71 private: 72 std::set<CNodePtr> nodes_; 73 std::set<CNodePtr> in_nodes_; 74 std::set<CNodePtr> out_nodes_; 75 const FuncGraphPtr belong_anf_ = nullptr; 76 const std::string name_; 77 }; 78 } // namespace mindspore::lite 79 #endif // MINDSPORE_LITE_TOOLS_COMMON_FUNC_GRAPH_SUBGRAPH_H_ 80