• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2021 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #ifndef MINDSPORE_LITE_TOOLS_COMMON_FUNC_GRAPH_SUBGRAPH_H_
18 #define MINDSPORE_LITE_TOOLS_COMMON_FUNC_GRAPH_SUBGRAPH_H_
19 
20 #include <memory>
21 #include <string>
22 #include <vector>
23 #include <map>
24 #include <set>
25 #include <utility>
26 #include "src/common/log_adapter.h"
27 #include "include/errorcode.h"
28 #include "ir/anf.h"
29 #include "ir/func_graph.h"
30 
31 namespace mindspore::lite {
32 class SubGraph;
33 using SubGraphPtr = std::shared_ptr<SubGraph>;
34 class SubGraph {
35  public:
36   explicit SubGraph(FuncGraphPtr belong_anf, std::string graph_name = "")
belong_anf_(std::move (belong_anf))37       : belong_anf_(std::move(belong_anf)), name_(std::move(graph_name)) {}
38 
39   virtual ~SubGraph() = default;
40   int Init(const std::set<CNodePtr> &head_nodes = {});
41   int Reset(const std::set<CNodePtr> &nodes, const std::set<CNodePtr> &head_nodes = {});
42 
43   bool MergeSubGraph(const SubGraphPtr &subgraph);
44 
45   std::set<CNodePtr> GetNodes() const;
46   std::set<CNodePtr> GetInCNodes() const;
47   std::set<CNodePtr> GetOutCNodes() const;
48 
49   int ApplySubGraph();
50 
51  private:
52   std::set<CNodePtr> GetInputCNodes() const;
53   std::set<CNodePtr> GetOutputCNodes() const;
54   // init subgraph methods
55   int InitSubGraphNode(const std::set<CNodePtr> &head_nodes);
56   int InitSubGraphInNode();
57   int InitSubGraphOutNode();
58   // merge subgraph methods
59   std::set<CNodePtr> FindCommonOutputs(const SubGraphPtr &subgraph) const;
60   bool IfDependOnSameNode(const SubGraphPtr &subgraph) const;
61   // apply subgraph methods
62   SubGraphPtr FindBeforeSubGraphInBelongAnf() const;
63   SubGraphPtr FindAfterSubGraphInBelongAnf() const;
64   int CreateParameterForPartialSubGraph(const FuncGraphPtr &sub_graph, std::vector<AnfNodePtr> *partial_inputs,
65                                         std::map<AnfNodePtr, AnfNodePtr> *partial_inputs_and_subgraph_input_map);
66   int CreateCNodeForPartialSubGraph(const FuncGraphPtr &sub_graph,
67                                     const std::map<AnfNodePtr, AnfNodePtr> &partial_inputs_and_subgraph_input_map);
68   int CreatePartialInBelongAnf();
69   static int SetFuncGraphOutput(const FuncGraphPtr &graph, const std::set<CNodePtr> &outputs);
70 
71  private:
72   std::set<CNodePtr> nodes_;
73   std::set<CNodePtr> in_nodes_;
74   std::set<CNodePtr> out_nodes_;
75   const FuncGraphPtr belong_anf_ = nullptr;
76   const std::string name_;
77 };
78 }  // namespace mindspore::lite
79 #endif  // MINDSPORE_LITE_TOOLS_COMMON_FUNC_GRAPH_SUBGRAPH_H_
80