/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#ifndef MINDSPORE_LITE_SRC_RUNTIME_SUB_GRAPH_SPLIT_H_
#define MINDSPORE_LITE_SRC_RUNTIME_SUB_GRAPH_SPLIT_H_

#include <stack>
#include <vector>
#include <map>
#include <set>
#include <unordered_map>
#include "include/model.h"
#include "src/executor/kernel_exec.h"
#include "src/litert/lite_model.h"
#include "src/litert/inner_context.h"
#include "src/common/prim_util.h"
#include "nnacl/conv_parameter.h"

namespace mindspore::lite {
constexpr int kDefaultSubGraphSize = 2;
constexpr int kDefaultFirstSubgraph = 0;
constexpr int kDefaultSecondSubgraph = 1;
constexpr int kDefaultInputs = 1;
constexpr int kMaxMultyInNode = 20;
constexpr int kMaxSubGraphCount = 10;
constexpr int kMinSubgraphCost = 50;
constexpr double kDefaultGpu = 0.5;
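/* Searches the loaded model for subgraph partitions that can run in parallel.
 * The interface below exposes three splitting strategies (by graph output, by
 * middle multi-input nodes, and by an offline parallel flag); candidate
 * subgraphs are costed with a simple mul/IO model and then assigned to devices
 * and threads. */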
class SearchSubGraph {
 public:
  enum TensorType { NORMAL, CONSTANT, INPUT };

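  /* Adjacency record for a single tensor: the nodes that consume it and the
   * nodes that produce it, plus whether it is a normal, constant, or graph-input
   * tensor. */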
  struct Tensor {
    std::vector<uint32_t> in_nodes_; /* used current tensor as input */
    std::vector<uint32_t> out_nodes_;
    TensorType type_;
  };

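  /* Rough execution cost of a group of nodes: mul_cost_ counts arithmetic work
   * and io_cost_ counts data movement. Costs compose additively, so fusing two
   * candidate subgraphs can simply add their CostModels via operator+. */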
  struct CostModel {
    size_t mul_cost_ = 0;
    size_t io_cost_ = 0;

    CostModel operator+(const SearchSubGraph::CostModel &cost) {
      CostModel result;
      result.mul_cost_ = this->mul_cost_ + cost.mul_cost_;
      result.io_cost_ = this->io_cost_ + cost.io_cost_;
      return result;
    }
    CostModel operator-(const SearchSubGraph::CostModel &cost) {
      CostModel result;
      result.mul_cost_ = this->mul_cost_ - cost.mul_cost_;
      result.io_cost_ = this->io_cost_ - cost.io_cost_;
      return result;
    }
    int cost() { return io_cost_ + mul_cost_; }
    void empty() {
      io_cost_ = 0;
      mul_cost_ = 0;
    }
  };

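  /* One candidate partition of the graph: the node indices it contains, its
   * head (input-side) and end (output-side) nodes, the device/thread assignment,
   * and the estimated cost of running it. */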
  struct Subgraph {
    std::vector<uint32_t> nodes_;
    std::vector<uint32_t> heads_;
    std::vector<uint32_t> ends_;
    bool search_terminate_ = false;
    DeviceType device_;
    size_t thread_;
    CostModel cost_;
    uint32_t tid_; /* 1 or 2 */
  };

 public:
  SearchSubGraph(const InnerContext *context, Model *model, std::vector<lite::Tensor *> *src_tensors,
                 const std::map<int, OpParameter *> *op_parameters, std::vector<size_t> *output_nodes);
  ~SearchSubGraph() = default;

 public:
  void SubGraphSplit();
  void SubGraphSplitByOperator();
  void InsertNodeBegin(uint32_t index, Subgraph *subgraph, std::vector<size_t> *outputs);

 private: /* split by output */
  void SubGraphSplitByOutput();
  void InitSearchSubGraphByOutput();
  void InsertNode(uint32_t index, Subgraph *subgraph, uint32_t last_index);

 private: /* split by middle */
  void SubGraphSplitByMiddle();
  void InitSearchSubGraphByMiddle();
  void SearchMultyInNodes(std::vector<uint32_t> *multy_in_nodes);
  void InitMiddleSubgraph(const std::vector<uint32_t> *multy_in_nodes);
  void InsertNodeByMid(uint32_t node_index, Subgraph *subgraph, uint32_t last_index);
  void InsertHeadNode(uint32_t index, Subgraph *subgraph);
  void OptimizeAfterFusion(std::vector<Subgraph> *sub_graphs, uint32_t root_node_index);

 private: /* split by offline */
  void SubGraphSplitByOffLineParallel();
  void UpdateOfflineParallelFlag();
  bool CheckIsParallelSubGraph(const std::vector<Subgraph> &subgraphs);

 private: /* public graph func  */
  void RemoveConstNode(std::vector<uint32_t> *nodes);
  void InitSearchTensor();
  void InitMainGraphDevice(DeviceType dt = DT_CPU);

  void InitSubgraphRuntimeInfo(std::vector<Subgraph> *sub_graphs);
  void SubgraphFusion(std::vector<Subgraph> *sub_graphs);
  void CalculateCostModel(std::vector<Subgraph> *sub_graphs);
  void ConvertSubGraphToModel(std::vector<Subgraph> *sub_graphs);
  bool ValidInParallel();
  void CheckSubHeadEnd(Subgraph *sub);

 private: /* public schema func  */
  void InsertParallelNode(uint32_t index, Subgraph *subgraph);
  bool IsNodeSubGraphHead(uint32_t node_index, const std::vector<uint32_t> &ready_nodes);
  bool IsNodeSubGraphHeadWithRoot(uint32_t node_index, const std::vector<uint32_t> &ready_nodes,
                                  uint32_t root_node_index);
  const schema::Primitive *CreatePartialPrimitive(int64_t subgraph_index);

 private: /* public cost-model func  */
  CostModel CalculateConv2DFusion(const LiteGraph::Node *node);
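  /* Recursive helper that searches groupings of the candidate subgraphs:
   * tmp_group marks the current selection, cor_group the best one found so far,
   * and min_value tracks how far the accumulated cost (current_sum) is from the
   * target except_value. (Description inferred from the parameter names.) */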
  void dfs(int i, int n, int current_sum, int except_value, int *min_value, std::vector<bool> *tmp_group,
           std::vector<bool> *cor_group, std::vector<Subgraph> *sub_graphs);

 public:
  const InnerContext *context_ = nullptr;
  LiteModel *model_ = nullptr;
  std::vector<Tensor> tensors_;
  std::vector<lite::Tensor *> *src_tensors_ = nullptr;

 private:
  std::vector<size_t> *output_nodes_ = nullptr;
  const std::map<int, OpParameter *> *op_parameters_ = nullptr;
  std::vector<Subgraph> sub_graphs_;
  std::unordered_map<uint32_t, std::vector<Subgraph>> node_sub_map_;
  std::vector<LiteGraph::Node *> node_list_;
  DeviceType major_dt_;
  DeviceType minor_dt_;
  size_t major_thread_;
  size_t minor_thread_;
  size_t total_cost_ = 0;
  bool offline_parallel_enable_ = false;
};
}  // namespace mindspore::lite

#endif  // MINDSPORE_LITE_SRC_RUNTIME_SUB_GRAPH_SPLIT_H_