• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2020 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #include "src/sub_graph_split.h"
18 #include <stdlib.h>
19 #include <utility>
20 #include <algorithm>
21 #include <iterator>
22 #include <vector>
23 #include "src/tensor.h"
24 #include "schema/ops_generated.h"
25 #include "schema/model_generated.h"
26 #include "src/ops/populate/populate_register.h"
27 #include "src/scheduler.h"
28 #include "nnacl/pooling_parameter.h"
29 #include "include/model.h"
30 #include "nnacl/base/conv_common_base.h"
31 
32 namespace mindspore::lite {
CommConvMul(std::vector<int> weight_shape,std::vector<int> output_shape)33 size_t CommConvMul(std::vector<int> weight_shape, std::vector<int> output_shape) {
34   size_t cost = output_shape[NHWC_N] * output_shape[NHWC_H] * output_shape[NHWC_W] * output_shape[NHWC_C] *
35                 weight_shape[NHWC_H] * weight_shape[NHWC_W] * weight_shape[NHWC_C];
36   return cost;
37 }
38 
/* Winograd convolution multiply cost.
 * Placeholder in the current cost model: Winograd work is treated as free. */
size_t WinogradConvMul() { return 0; }
43 
CommConvdwMul(std::vector<int> weight_shape,std::vector<int> output_shape)44 size_t CommConvdwMul(std::vector<int> weight_shape, std::vector<int> output_shape) {
45   size_t cost = static_cast<size_t>(output_shape[NHWC_N] * output_shape[NHWC_H] * output_shape[NHWC_W] *
46                                     output_shape[NHWC_C] * weight_shape[NHWC_H] * weight_shape[NHWC_W]);
47   return cost;
48 }
49 
/* Winograd depthwise-convolution multiply cost.
 * Placeholder in the current cost model: treated as free. */
size_t WinogradConvDwMul() { return 0; }
54 
IsOfflineParallelNode(const void * node_primitive,int node_device_type)55 bool IsOfflineParallelNode(const void *node_primitive, int node_device_type) {
56   if (node_primitive == nullptr) {
57     return false;
58   }
59   return (GetPrimitiveType(node_primitive, SCHEMA_VERSION::SCHEMA_CUR) == schema::PrimitiveType_Conv2DFusion) &&
60          (node_device_type != kDefaultDeviceType);
61 }
62 
UpdateOfflineParallelFlag()63 void SearchSubGraph::UpdateOfflineParallelFlag() {
64   if (model_ == nullptr) {
65     offline_parallel_enable_ = false;
66     return;
67   }
68   // visited whole models to find any conv && depthwise conv have been set to device type
69   offline_parallel_enable_ = std::any_of(
70     this->model_->graph_.all_nodes_.begin(), this->model_->graph_.all_nodes_.end(),
71     [&](lite::LiteGraph::Node *node) { return IsOfflineParallelNode(node->primitive_, node->device_type_); });
72 }
73 
/* Check whether |subgraphs| matches the offline-parallel pattern: exactly
 * kDefaultSubGraphSize subgraphs, each with a single head and a single end
 * that are both device-typed Conv2DFusion nodes, where every non-const head
 * input is produced by a SplitWithOverlap node and every non-const end
 * output is consumed by a Concat node. */
bool SearchSubGraph::CheckIsParallelSubGraph(const std::vector<Subgraph> &subgraphs) {
  if (subgraphs.size() != kDefaultSubGraphSize) {
    return false;
  }

  for (const auto &sub_graph : subgraphs) {
    auto heads = sub_graph.heads_;
    auto ends = sub_graph.ends_;
    if (heads.size() != kDefaultInputs || ends.size() != kDefaultInputs) {
      return false;
    }
    auto head_node = model_->graph_.all_nodes_.at(heads.front());
    auto end_node = model_->graph_.all_nodes_.at(ends.front());
    // head and end must both be offline-parallel convolutions
    if (!IsOfflineParallelNode(head_node->primitive_, head_node->device_type_) ||
        !IsOfflineParallelNode(end_node->primitive_, end_node->device_type_)) {
      return false;
    }

    // 1. check head_node's input is SplitOverlap node
    for (const auto &input : head_node->input_indices_) {
      if (tensors_.at(input).type_ == CONST) {
        continue;  // weights/bias are irrelevant to the dataflow pattern
      }
      auto input_node_index = tensors_.at(input).out_nodes_.front();
      if (GetPrimitiveType(model_->graph_.all_nodes_.at(input_node_index)->primitive_, SCHEMA_VERSION::SCHEMA_CUR) !=
          schema::PrimitiveType_SplitWithOverlap) {
        return false;
      }
    }

    // 2. check end_node's output is concat node
    for (const auto &output : end_node->output_indices_) {
      if (tensors_.at(output).type_ == CONST) {
        continue;
      }
      auto output_node_index = tensors_.at(output).in_nodes_.front();
      if (GetPrimitiveType(model_->graph_.all_nodes_.at(output_node_index)->primitive_, SCHEMA_VERSION::SCHEMA_CUR) !=
          schema::PrimitiveType_Concat) {
        return false;
      }
    }
  }
  return true;
}
118 
dfs(int i,int n,int current_sum,int except_value,int * min_value,std::vector<bool> * tmp_group,std::vector<bool> * cor_group,std::vector<Subgraph> * sub_graphs)119 void SearchSubGraph::dfs(int i, int n, int current_sum, int except_value, int *min_value, std::vector<bool> *tmp_group,
120                          std::vector<bool> *cor_group, std::vector<Subgraph> *sub_graphs) {
121   if (i == n) {
122     if (abs(except_value - current_sum) < *min_value) {
123       for (int j = 0; j < n; j++) {
124         cor_group->at(j) = tmp_group->at(j);
125       }
126     }
127     *min_value = MSMIN(*min_value, abs(except_value - current_sum));
128     return;
129   }
130 
131   {
132     tmp_group->at(i) = true;
133     int next_sum = current_sum + sub_graphs->at(i).cost_.cost();
134     dfs(i + 1, n, next_sum, except_value, min_value, tmp_group, cor_group, sub_graphs);
135   }
136 
137   {
138     tmp_group->at(i) = false;
139     dfs(i + 1, n, current_sum, except_value, min_value, tmp_group, cor_group, sub_graphs);
140   }
141   return;
142 }
143 
/* Estimate the multiply cost of one Conv2DFusion node from its weight shape
 * (input index 1) and output shape (output index 0).
 * NOTE(review): op_parameters_ is indexed by the node's first output tensor
 * index here — confirm that is the map's keying convention. */
SearchSubGraph::CostModel SearchSubGraph::CalculateConv2DFusion(LiteGraph::Node *node) {
  CostModel cost;
  std::vector<uint32_t> inputs = node->input_indices_;
  std::vector<uint32_t> outputs = node->output_indices_;

  std::vector<int> weight_shape = src_tensors_->at(inputs[1])->shape();
  std::vector<int> output_shape = src_tensors_->at(outputs[0])->shape();

  ConvParameter *param = reinterpret_cast<ConvParameter *>(op_parameters_->at(outputs[0]));

  if (param->group_ == 1) {
    /* ordinary convolution */
    if (param->kernel_h_ == 1 && param->kernel_w_ == 1) {
      size_t conv1x1_mul_cost = CommConvMul(weight_shape, output_shape);
      cost.mul_cost_ += conv1x1_mul_cost;
    } else {
      int out_unit;
      if (CheckIfUseWinograd(&out_unit, param)) {
        /* Winograd-eligible, but costed like a common conv —
         * WinogradConvMul() above returns 0 and is not used */
        size_t winograd_conv_cost = CommConvMul(weight_shape, output_shape);
        cost.mul_cost_ += winograd_conv_cost;
      } else {
        size_t comm_conv_mul_cost = CommConvMul(weight_shape, output_shape);
        cost.mul_cost_ += comm_conv_mul_cost;
      }
    }
  } else if (param->group_ == param->input_channel_ && param->group_ == param->output_channel_) {
    /* depthwise convolution (group == Cin == Cout) */
#if defined(ENABLE_ARM) || (defined(ENABLE_SSE) && !defined(ENABLE_AVX))
    if (CheckConvDw1DWinograd(param, context_->thread_num_)) {
      /* ConvolutionDepthwise3x3CPUKernel */
      size_t winograd_convdw_cost = CommConvdwMul(weight_shape, output_shape);
      cost.mul_cost_ += winograd_convdw_cost;
    } else {
      /* ConvolutionDepthwiseIndirectCPUKernel */
      /* ConvolutionDepthwiseSWCPUKernel */
      /* ConvolutionDepthwiseCPUKernel */
      size_t comm_convdw_cost = CommConvdwMul(weight_shape, output_shape);
      cost.mul_cost_ += comm_convdw_cost;
    }
#else
    size_t comm_convdw_cost = CommConvdwMul(weight_shape, output_shape);
    cost.mul_cost_ += comm_convdw_cost;
#endif
  } else {
    /* group conv: left uncosted (cost stays default-initialized) */
  }
  return cost;
}
190 
CreatePartialPrimitive(int64_t subgraph_index)191 const schema::Primitive *SearchSubGraph::CreatePartialPrimitive(int64_t subgraph_index) {
192   flatbuffers::FlatBufferBuilder fbb(1024);
193   auto val_offset = schema::CreatePartialFusion(fbb, subgraph_index);
194   auto prim_offset = schema::CreatePrimitive(fbb, schema::PrimitiveType_PartialFusion, val_offset.o);
195   fbb.Finish(prim_offset);
196   auto tmp_buf = fbb.GetBufferPointer();
197   void *prim_buf = malloc(fbb.GetSize());
198   if (prim_buf == nullptr) {
199     return nullptr;
200   }
201   memcpy(prim_buf, tmp_buf, fbb.GetSize());
202 
203   auto primitive = flatbuffers::GetRoot<schema::Primitive>(prim_buf);
204   fbb.Clear();
205 
206   model_->node_bufs_.push_back(prim_buf);
207   return std::move(primitive);
208 }
209 
/* Materialize each searched Subgraph as a real LiteGraph::SubGraph plus a
 * PartialFusion node inserted into the main graph:
 *  - moves the subgraph's nodes out of the main graph's node_indices_,
 *  - wires the new subgraph's/partial node's input and output tensors from
 *    the subgraph heads and ends,
 *  - inserts the partial node at the position of the earliest removed node
 *    so topological order is preserved.
 * Only runs when exactly kDefaultSubGraphSize subgraphs were produced;
 * clears |sub_graphs| when done. */
void SearchSubGraph::ConvertSubGraphToModel(std::vector<Subgraph> *sub_graphs) {
  if (sub_graphs->size() != kDefaultSubGraphSize) {
    return;
  }
  LiteGraph::SubGraph *main_graphs = model_->graph_.sub_graphs_.front();

  for (Subgraph &subgraph : *sub_graphs) {
    if (subgraph.nodes_.empty()) {
      continue;
    }

    DeviceType device_type = subgraph.device_;
    size_t thread_num = subgraph.thread_;
    int new_sub_index = static_cast<int>(model_->graph_.sub_graphs_.size());
    int partial_index = static_cast<int>(model_->graph_.all_nodes_.size());
    // position in the main graph where the partial node will be spliced in
    int particial_replace_index = partial_index;

    LiteGraph::SubGraph *new_sub_graph = new (std::nothrow) LiteGraph::SubGraph();
    if (new_sub_graph == nullptr) {
      MS_LOG(ERROR) << "New sub graph failed!";
      return;
    }
    new_sub_graph->name_ = "SubSplit" + std::to_string(new_sub_index);
    LiteGraph::Node *new_partial_node = new (std::nothrow) LiteGraph::Node();
    if (new_partial_node == nullptr) {
      MS_LOG(ERROR) << "New partial node failed!";
      delete new_sub_graph;
      return;
    }
    new_partial_node->name_ = "SubSplitPartial" + std::to_string(new_sub_index);
    // prefix the name with the target device for readability in dumps
    if (device_type == DT_CPU) {
      new_partial_node->name_ = "Cpu" + new_partial_node->name_;
    } else if (device_type == DT_GPU) {
      new_partial_node->name_ = "Gpu" + new_partial_node->name_;
    } else if (device_type == DT_NPU) {
      new_partial_node->name_ = "Npu" + new_partial_node->name_;
    }

    new_partial_node->node_type_ = static_cast<int>(mindspore::lite::MSNodeType::NodeType_ValueNode);
    new_partial_node->primitive_ = CreatePartialPrimitive(new_sub_index);

    // move every searched node from the main graph into the new subgraph
    while (!subgraph.nodes_.empty()) {
      uint32_t node_index = subgraph.nodes_.front();
      LiteGraph::Node *cur_node = model_->graph_.all_nodes_[node_index];
      new_sub_graph->node_indices_.push_back(node_index);

      // track the earliest main-graph position among the removed nodes
      auto iter = find(main_graphs->node_indices_.begin(), main_graphs->node_indices_.end(), node_index);
      int cur_node_index = std::distance(std::begin(main_graphs->node_indices_), iter);
      particial_replace_index = (cur_node_index < particial_replace_index) ? cur_node_index : particial_replace_index;

      VectorErase(&main_graphs->node_indices_, node_index);
      VectorErase(&subgraph.nodes_, node_index);
      cur_node->device_type_ = static_cast<int>(device_type);
      op_parameters_->at(static_cast<int>(cur_node->output_indices_.at(0)))->thread_num_ = thread_num;
    }

    // subgraph inputs: non-const head inputs whose producer is outside the subgraph
    for (uint32_t head_index : subgraph.heads_) {
      LiteGraph::Node *head_node = model_->graph_.all_nodes_[head_index];
      std::vector<uint32_t> inputs = head_node->input_indices_;
      for (auto input : inputs) {
        if (tensors_[input].type_ == CONST) {
          continue;
        }
        if (std::find(new_sub_graph->input_indices_.begin(), new_sub_graph->input_indices_.end(), input) !=
            new_sub_graph->input_indices_.end()) {
          continue;  // already recorded
        }

        // skip tensors produced inside the new subgraph itself
        auto in_tensor_in_nodes = tensors_[input].out_nodes_;
        if (!in_tensor_in_nodes.empty()) {
          uint32_t in_tensor_in_node = in_tensor_in_nodes[0];
          if (std::find(new_sub_graph->node_indices_.begin(), new_sub_graph->node_indices_.end(), in_tensor_in_node) !=
              new_sub_graph->node_indices_.end()) {
            continue;
          }
        }

        new_sub_graph->input_indices_.insert(new_sub_graph->input_indices_.end(), input);
        new_partial_node->input_indices_.insert(new_partial_node->input_indices_.end(), input);
      }
    }

    // subgraph outputs: every output tensor of every end node
    for (uint32_t end_index : subgraph.ends_) {
      LiteGraph::Node *end_node = model_->graph_.all_nodes_[end_index];
      std::vector<uint32_t> outputs = end_node->output_indices_;
      new_sub_graph->output_indices_.insert(new_sub_graph->output_indices_.end(), outputs.begin(), outputs.end());
      new_partial_node->output_indices_.insert(new_partial_node->output_indices_.end(), outputs.begin(), outputs.end());
    }

    // splice the partial node where the first removed node used to be
    main_graphs->node_indices_.insert(main_graphs->node_indices_.begin() + particial_replace_index, partial_index);
    model_->graph_.all_nodes_.push_back(std::move(new_partial_node));
    model_->graph_.sub_graphs_.push_back(std::move(new_sub_graph));
  }

  sub_graphs->clear();
  return;
}
307 
IsNodeSubGraphHead(uint32_t node_index,const std::vector<uint32_t> & ready_nodes)308 bool SearchSubGraph::IsNodeSubGraphHead(uint32_t node_index, const std::vector<uint32_t> &ready_nodes) {
309   std::vector<uint32_t> output_indexes = model_->graph_.all_nodes_.at(node_index)->output_indices_;
310   std::vector<uint32_t> output_nodes;
311   for (uint32_t out_t : output_indexes) {
312     std::vector<uint32_t> cur_nodes = tensors_[out_t].in_nodes_;
313     output_nodes.insert(output_nodes.end(), cur_nodes.begin(), cur_nodes.end());
314   }
315   if (output_indexes.size() == 1 && output_nodes.size() == 1) {
316     return false;
317   }
318   for (uint32_t out_n : output_nodes) {
319     if (find(ready_nodes.begin(), ready_nodes.end(), out_n) == ready_nodes.end()) {
320       return true;
321     }
322   }
323   return false;
324 }
325 
IsNodeSubGraphHeadWithRoot(uint32_t node_index,const std::vector<uint32_t> & ready_nodes,uint32_t root_node_index)326 bool SearchSubGraph::IsNodeSubGraphHeadWithRoot(uint32_t node_index, const std::vector<uint32_t> &ready_nodes,
327                                                 uint32_t root_node_index) {
328   std::vector<uint32_t> output_indexes = node_list_.at(node_index)->output_indices_;
329   std::vector<uint32_t> output_nodes;
330   for (uint32_t out_t : output_indexes) {
331     std::vector<uint32_t> cur_nodes = tensors_[out_t].in_nodes_;
332     output_nodes.insert(output_nodes.end(), cur_nodes.begin(), cur_nodes.end());
333   }
334   for (uint32_t out_n : output_nodes) {
335     if (root_node_index != out_n) {
336       if (find(ready_nodes.begin(), ready_nodes.end(), out_n) == ready_nodes.end()) {
337         return true;
338       }
339     }
340   }
341   return false;
342 }
343 
SearchMultyInNodes(std::vector<uint32_t> * multy_in_nodes)344 void SearchSubGraph::SearchMultyInNodes(std::vector<uint32_t> *multy_in_nodes) {
345   std::vector<uint32_t> all_main_sub_nodes = model_->graph_.sub_graphs_[0]->node_indices_;
346   for (size_t i = 0; i < all_main_sub_nodes.size(); i++) {
347     uint32_t node_index = all_main_sub_nodes[i];
348     LiteGraph::Node *node = node_list_[node_index];
349 
350     if (IsPartialNode(node->primitive_, model_->GetSchemaVersion())) {
351       continue;
352     }
353     int input_count = std::count_if(node->input_indices_.begin(), node->input_indices_.end(),
354                                     [&](uint32_t in_tensor_index) { return tensors_[in_tensor_index].type_ != CONST; });
355     if (input_count > 1) {
356       multy_in_nodes->push_back(node_index);
357     }
358   }
359   return;
360 }
361 
RemoveConstNode(std::vector<uint32_t> * nodes)362 void SearchSubGraph::RemoveConstNode(std::vector<uint32_t> *nodes) {
363   bool stop_search = false;
364   while (!stop_search) {
365     stop_search = true;
366     for (size_t i = 0; i < nodes->size(); i++) {
367       if (tensors_[nodes->at(i)].type_ == CONST) {
368         VectorErase(nodes, nodes->at(i));
369         stop_search = false;
370         break;
371       }
372     }
373   }
374 
375   return;
376 }
377 
/* Recursively grow |subgraph| backwards (from outputs toward inputs)
 * starting at node |index|; |last_index| is the node that pulled |index| in
 * and becomes the subgraph head if the walk must stop here. Claimed nodes
 * are nulled out of node_list_ so each node joins exactly one subgraph. */
void SearchSubGraph::InsertNode(uint32_t index, Subgraph *subgraph, uint32_t last_index) {
  if (subgraph->search_terminate_) {
    return;
  }

  LiteGraph::Node *node = node_list_.at(index);
  if (node == nullptr) {
    // already claimed by an earlier search
    return;
  }

  std::vector<uint32_t> input = node->input_indices_;
  RemoveConstNode(&input);

  /* all node_input is graph_input */
  /* NOTE(review): only input[0] is actually inspected — the first iteration
   * either breaks (not a graph input) or clears-and-terminates; confirm
   * whether "all inputs" was the intent. */
  for (size_t i = 0; i < input.size(); i++) {
    if (tensors_[input[i]].type_ != INPUT) {
      break;
    }
    // reached the graph boundary: abandon this candidate subgraph
    subgraph->heads_.clear();
    subgraph->ends_.clear();
    subgraph->nodes_.clear();
    subgraph->search_terminate_ = true;
    return;
  }

  /* split in graph */
  if (IsNodeSubGraphHead(index, subgraph->nodes_)) {
    if (subgraph->nodes_.empty()) {
      // the very first node is already a head: nothing to build
      subgraph->search_terminate_ = true;
      return;
    }
    subgraph->heads_.push_back(last_index);
    return;
  }

  // graph output nodes become subgraph ends
  if (find(output_nodes_->begin(), output_nodes_->end(), index) != output_nodes_->end()) {
    subgraph->ends_.push_back(index);
  }

  /* node insert in current subgraph */
  subgraph->nodes_.insert(subgraph->nodes_.begin(), index);
  node_list_.at(index) = nullptr;

  /* search for next node */
  for (uint32_t in : input) {
    auto next_nodes = tensors_[in].out_nodes_;
    for (uint32_t next_node : next_nodes) {
      InsertNode(next_node, subgraph, index);
    }
  }
  return;
}
430 
/* After fusion, try to absorb each subgraph head's single producer into the
 * subgraph (when that producer has one input, one consumer, and all of its
 * outputs stay inside the subgraph, excluding |root_node_index|). Used heads
 * are dropped afterwards; finally heads/ends are re-validated and node
 * indices sorted. */
void SearchSubGraph::OptimizeAfterFusion(std::vector<Subgraph> *sub_graphs, uint32_t root_node_index) {
  MS_ASSERT(sub_graphs->size() == kDefaultSubGraphSize);
  for (Subgraph &sub : *sub_graphs) {
    if (sub.nodes_.empty()) {
      return;
    }
    int head_size = static_cast<int>(sub.heads_.size());
    std::vector<uint32_t> used_heads;
    for (int i = 0; i < head_size; i++) {
      uint32_t head_node_index = sub.heads_.at(i);
      // NOTE(review): 'break' stops scanning ALL remaining heads once a
      // used head is met — confirm 'continue' was not intended.
      if (std::find(used_heads.begin(), used_heads.end(), head_node_index) != used_heads.end()) {
        break;
      }
      std::vector<uint32_t> head_input_tensors = model_->graph_.all_nodes_[head_node_index]->input_indices_;
      RemoveConstNode(&head_input_tensors);
      if (head_input_tensors.size() != 1) continue;

      std::vector<uint32_t> input_nodes = tensors_.at(head_input_tensors.at(0)).out_nodes_;
      if (input_nodes.size() != 1) continue;
      uint32_t input_node_index = input_nodes.at(0);

      std::vector<uint32_t> input_tensors = model_->graph_.all_nodes_[input_node_index]->input_indices_;
      RemoveConstNode(&input_tensors);
      if (input_tensors.size() != 1) continue;

      /* this node qualified:
       * 1. the only input node of current head node
       * 2. all output included in current subgraph
       * 3. one input-tensor */
      if (!IsNodeSubGraphHeadWithRoot(input_node_index, sub.nodes_, root_node_index)) {
        InsertHeadNode(input_node_index, &sub);
        used_heads.push_back(head_node_index); /* delete used head at end */
      }
      // InsertHeadNode may have appended new heads; refresh the bound
      head_size = static_cast<int>(sub.heads_.size());
    }
    for (auto head_index : used_heads) {
      VectorErase(&sub.heads_, head_index);
    }

    CheckSubHeadEnd(&sub);

    /* sort node index  */
    std::sort(sub.nodes_.begin(), sub.nodes_.end());
  }
}
476 
InsertHeadNode(uint32_t head_node_index,Subgraph * subgraph)477 void SearchSubGraph::InsertHeadNode(uint32_t head_node_index, Subgraph *subgraph) {
478   LiteGraph::Node *node = node_list_.at(head_node_index);
479   std::vector<uint32_t> head_node_inputs = node->input_indices_;
480   RemoveConstNode(&head_node_inputs);
481 
482   subgraph->nodes_.push_back(head_node_index);
483   node_list_.at(head_node_index) = nullptr;
484 
485   /* search for next node */
486   size_t current_node_size = subgraph->nodes_.size();
487   for (uint32_t in : head_node_inputs) {
488     auto next_nodes = tensors_[in].out_nodes_;
489     for (uint32_t next_node : next_nodes) {
490       InsertNodeByMid(next_node, subgraph, head_node_index);
491     }
492   }
493 
494   if (current_node_size == subgraph->nodes_.size()) {
495     subgraph->heads_.push_back(head_node_index);
496   }
497   return;
498 }
499 
/* Grow |subgraph| upward from |node_index| during the middle-split search.
 * If the node is a previously searched multi-input node (present in
 * node_sub_map_), its whole pre-computed unit is spliced in; otherwise the
 * walk recurses through the producers of the node's non-const inputs.
 * |last_index| becomes a head whenever the walk must stop at a boundary. */
void SearchSubGraph::InsertNodeByMid(uint32_t node_index, Subgraph *subgraph, uint32_t last_index) {
  LiteGraph::Node *node = node_list_.at(node_index);
  MS_CHECK_PTR_IF_NULL(node);

  auto subs_iter = node_sub_map_.find(node_index);
  if (subs_iter != node_sub_map_.end()) {
    /* node is multy-in node , already searched before */

    if (IsNodeSubGraphHead(node_index, subgraph->nodes_)) {
      /* this node can not be included in this subgraph */
      if (!subgraph->nodes_.empty()) subgraph->heads_.push_back(last_index);
      return;
    }

    /* include this multy-in-unit in current subgraph */
    std::vector<Subgraph> &subs = subs_iter->second;

    /* insert nodes */
    subgraph->nodes_.push_back(node_index);
    for (Subgraph &sub : subs) {
      subgraph->nodes_.insert(subgraph->nodes_.end(), sub.nodes_.begin(), sub.nodes_.end());
    }

    /* insert heads */
    std::set<uint32_t> subs_head;
    subs_head.insert(node_index);
    for (Subgraph &sub : subs) {
      for (uint32_t head : sub.heads_) {
        subs_head.insert(head);
      }
    }

    /* try to extend past each inherited head; heads that were extended are
     * removed from the backup list and only the rest become real heads */
    std::set<uint32_t> subs_head_baklist = subs_head;
    for (uint32_t head_node : subs_head) {
      std::vector<uint32_t> head_input_tensors = model_->graph_.all_nodes_[head_node]->input_indices_;
      RemoveConstNode(&head_input_tensors);
      if (head_input_tensors.size() != 1) continue;
      std::vector<uint32_t> input_nodes = tensors_.at(head_input_tensors.at(0)).out_nodes_;
      if (input_nodes.size() != 1) continue;

      uint32_t input_node = input_nodes.at(0);
      if (!IsNodeSubGraphHead(input_node, subgraph->nodes_)) {
        InsertNodeByMid(input_node, subgraph, head_node);
        subs_head_baklist.erase(head_node);
      }
    }

    /* stop search  */
    for (auto head : subs_head_baklist) {
      subgraph->heads_.push_back(head);
    }
    // the unit is consumed: it must not be spliced into another subgraph
    node_sub_map_.erase(node_index);
    return;
  }

  std::vector<uint32_t> inputs = node->input_indices_;
  RemoveConstNode(&inputs);

  if (IsNodeSubGraphHead(node_index, subgraph->nodes_)) {
    if (!subgraph->nodes_.empty()) {
      // record the boundary once; avoid duplicate heads
      if (std::find(subgraph->heads_.begin(), subgraph->heads_.end(), last_index) == subgraph->heads_.end()) {
        subgraph->heads_.push_back(last_index);
      }
    }
    return;
  }

  subgraph->nodes_.insert(subgraph->nodes_.begin(), node_index);
  node_list_.at(node_index) = nullptr;

  /* search for next node */
  for (uint32_t in : inputs) {
    auto next_nodes = tensors_[in].out_nodes_;
    if (next_nodes.size() == 0) {
      // input tensor has no producer: the earliest absorbed node is the head
      if (!subgraph->nodes_.empty()) subgraph->heads_.push_back(subgraph->nodes_.front());
    } else {
      for (uint32_t next_node : next_nodes) {
        InsertNodeByMid(next_node, subgraph, node_index);
      }
    }
  }
  return;
}
583 
InitMiddleSubgraph(std::vector<uint32_t> * multy_in_nodes)584 void SearchSubGraph::InitMiddleSubgraph(std::vector<uint32_t> *multy_in_nodes) {
585   for (uint32_t node_index : *multy_in_nodes) {
586     std::vector<Subgraph> node_subs;
587     LiteGraph::Node *node = node_list_[node_index];
588     for (uint32_t input_tensor_index : node->input_indices_) {
589       Tensor *tensor = &tensors_[input_tensor_index];
590       if (tensor->type_ == CONST || tensor->type_ == INPUT) continue;
591 
592       std::vector<uint32_t> input_nodes = tensor->out_nodes_;
593       if (input_nodes.empty()) continue;
594       if (input_nodes.size() != 1) continue;
595       uint32_t input_node = input_nodes[0];
596 
597       Subgraph sub;
598       sub.ends_.push_back(input_node);
599       InsertNodeByMid(input_node, &sub, input_node);
600       node_subs.push_back(sub);
601     }
602     if (!node_subs.empty()) {
603       node_sub_map_.insert(std::make_pair(node_index, node_subs));
604     }
605   }
606   return;
607 }
608 
InitSearchSubGraphByMiddle()609 void SearchSubGraph::InitSearchSubGraphByMiddle() {
610   sub_graphs_.clear();
611   node_list_ = model_->graph_.all_nodes_;
612 
613   std::vector<uint32_t> multy_in_nodes;
614 
615   SearchMultyInNodes(&multy_in_nodes);
616 
617   if (multy_in_nodes.size() > kMaxMultyInNode) {
618     node_sub_map_.clear();
619     return;
620   }
621 
622   InitMiddleSubgraph(&multy_in_nodes);
623 
624   if (node_sub_map_.size() > kMaxSubGraphCount) {
625     node_sub_map_.clear();
626   }
627   return;
628 }
629 
InitSearchSubGraphByOutput()630 void SearchSubGraph::InitSearchSubGraphByOutput() {
631   sub_graphs_.clear();
632   node_list_ = model_->graph_.all_nodes_;
633 
634   for (auto out : *output_nodes_) {
635     Subgraph subgraph;
636 
637     InsertNode(static_cast<uint32_t>(out), &subgraph, static_cast<uint32_t>(out));
638 
639     sub_graphs_.push_back(std::move(subgraph));
640   }
641   return;
642 }
643 
InitSearchTensor()644 void SearchSubGraph::InitSearchTensor() {
645   tensors_.resize(model_->graph_.all_tensors_.size());
646 
647   /* Set Tensor Type */
648   for (size_t i = 0; i < tensors_.size(); i++) {
649     tensors_[i].type_ = NORMAL;
650     mindspore::schema::Tensor *src_tensor = static_cast<schema::Tensor *>(model_->graph_.all_tensors_[i]);
651     auto category = TensorCategory(src_tensor);
652     if (category == mindspore::lite::Tensor::Category::CONST_TENSOR ||
653         category == mindspore::lite::Tensor::Category::CONST_SCALAR) {
654       tensors_[i].type_ = CONST;
655     }
656   }
657   std::vector<uint32_t> graph_input = model_->graph_.sub_graphs_[0]->input_indices_;
658   for (auto in : graph_input) {
659     tensors_[in].type_ = INPUT;
660   }
661 
662   /* Set Tensor In and out Node */
663   for (size_t index = 0; index < model_->graph_.all_nodes_.size(); index++) {
664     LiteGraph::Node *node = model_->graph_.all_nodes_[index];
665     std::vector<uint32_t> input = node->input_indices_;
666     for (uint32_t in : input) {
667       tensors_[in].in_nodes_.push_back(index);
668     }
669     std::vector<uint32_t> output = node->output_indices_;
670     for (uint32_t out : output) {
671       tensors_[out].out_nodes_.push_back(index);
672     }
673   }
674   return;
675 }
676 
InitSubgraphRuntimeInfo(std::vector<Subgraph> * sub_graphs)677 void SearchSubGraph::InitSubgraphRuntimeInfo(std::vector<Subgraph> *sub_graphs) {
678   std::vector<bool> tmp_group;
679   std::vector<bool> cor_group;
680 
681   tmp_group.resize(sub_graphs->size());
682   cor_group.resize(sub_graphs->size());
683 
684   int except_value = static_cast<int>(total_cost_ * kDefaultGpu); /* major device responsible for 50% calculation */
685   int min_value = INT32_MAX;
686 
687   dfs(0, static_cast<int>(sub_graphs->size()), 0, except_value, &min_value, &tmp_group, &cor_group, sub_graphs);
688 
689   /* make bigger half using major_dt_ */
690   int true_value = 0;
691   for (size_t i = 0; i < sub_graphs->size(); i++) {
692     if (cor_group.at(i)) {
693       true_value += sub_graphs->at(i).cost_.cost();
694     }
695   }
696 
697   if (true_value < except_value) {
698     (void)std::transform(cor_group.begin(), cor_group.end(), cor_group.begin(), [](bool value) { return !value; });
699   }
700 
701   for (size_t i = 0; i < sub_graphs->size(); i++) {
702     if (cor_group.at(i)) {
703       sub_graphs->at(i).device_ = major_dt_;
704       sub_graphs->at(i).thread_ = major_thread_;
705       sub_graphs->at(i).tid_ = 0;
706     } else {
707       sub_graphs->at(i).device_ = minor_dt_;
708       sub_graphs->at(i).thread_ = minor_thread_;
709       sub_graphs->at(i).tid_ = 1;
710     }
711   }
712 }
713 
InitMainGraphDevice(DeviceType dt)714 void SearchSubGraph::InitMainGraphDevice(DeviceType dt) {
715   LiteGraph::SubGraph *main_graph = model_->graph_.sub_graphs_.front();
716   for (uint32_t node_index : main_graph->node_indices_) {
717     LiteGraph::Node *node = model_->graph_.all_nodes_[node_index];
718     node->device_type_ = dt;
719   }
720 }
721 
SubgraphFusion(std::vector<Subgraph> * sub_graphs)722 void SearchSubGraph::SubgraphFusion(std::vector<Subgraph> *sub_graphs) {
723   while (sub_graphs->size() > kDefaultSubGraphSize) {
724     size_t sub1_index = 0;
725     size_t sub2_index = 0;
726     bool is_found = false;
727     for (sub1_index = 0; sub1_index < sub_graphs->size(); sub1_index++) {
728       for (size_t tmp2 = sub1_index + 1; tmp2 < sub_graphs->size(); tmp2++) {
729         if (sub_graphs->at(sub1_index).tid_ == sub_graphs->at(tmp2).tid_) {
730           sub2_index = tmp2;
731           is_found = true;
732           break;
733         }
734       }
735       if (is_found) {
736         break;
737       }
738     }
739     MS_ASSERT(sub2_index > sub1_index); /* erase sub2 then sub1 */
740 
741     Subgraph new_sub;
742     new_sub.device_ = sub_graphs->at(sub1_index).device_;
743     new_sub.thread_ = sub_graphs->at(sub1_index).thread_;
744     new_sub.tid_ = sub_graphs->at(sub1_index).tid_;
745     new_sub.cost_ = sub_graphs->at(sub1_index).cost_ + sub_graphs->at(sub2_index).cost_;
746 
747     Subgraph &sub1 = sub_graphs->at(sub1_index);
748     Subgraph &sub2 = sub_graphs->at(sub2_index);
749     new_sub.nodes_.insert(new_sub.nodes_.end(), sub1.nodes_.begin(), sub1.nodes_.end());
750     new_sub.nodes_.insert(new_sub.nodes_.end(), sub2.nodes_.begin(), sub2.nodes_.end());
751     new_sub.heads_.insert(new_sub.heads_.end(), sub1.heads_.begin(), sub1.heads_.end());
752     new_sub.heads_.insert(new_sub.heads_.end(), sub2.heads_.begin(), sub2.heads_.end());
753     new_sub.ends_.insert(new_sub.ends_.end(), sub1.ends_.begin(), sub1.ends_.end());
754     new_sub.ends_.insert(new_sub.ends_.end(), sub2.ends_.begin(), sub2.ends_.end());
755     sub_graphs->erase(sub_graphs->begin() + sub2_index);
756     sub_graphs->erase(sub_graphs->begin() + sub1_index);
757     sub_graphs->insert(sub_graphs->end(), std::move(new_sub));
758   }
759 
760   return;
761 }
762 
/* Accumulate a cost model per subgraph and the file-level total_cost_.
 * Every node starts with a baseline cost (io 0, mul 1); Conv2DFusion nodes
 * get a shape-derived cost from CalculateConv2DFusion. */
void SearchSubGraph::CalculateCostModel(std::vector<Subgraph> *sub_graphs) {
  total_cost_ = 0;
  for (Subgraph &subgraph : *sub_graphs) {
    // NOTE(review): if CostModel::empty() is a const query, its result is
    // discarded here and cost_ is never reset — confirm it actually zeroes
    // the members (reset semantics) rather than testing emptiness.
    subgraph.cost_.empty();
    std::vector<uint32_t> nodes = subgraph.nodes_;
    for (uint32_t node_index : nodes) {
      CostModel cost;
      cost.io_cost_ = 0;
      cost.mul_cost_ = 1;

      LiteGraph::Node *node = model_->graph_.all_nodes_[node_index];
      if (GetPrimitiveType(node->primitive_, SCHEMA_VERSION::SCHEMA_CUR) == schema::PrimitiveType_Conv2DFusion) {
        cost = CalculateConv2DFusion(node);
      }

      subgraph.cost_ = subgraph.cost_ + cost;
      total_cost_ += static_cast<size_t>(cost.cost());
    }
  }
}
783 
SubGraphSplitByOutput()784 void SearchSubGraph::SubGraphSplitByOutput() {
785   if (output_nodes_->size() < kDefaultSubGraphSize) {
786     return;
787   }
788 
789   InitSearchSubGraphByOutput();
790   CalculateCostModel(&sub_graphs_);
791   InitSubgraphRuntimeInfo(&sub_graphs_);
792   SubgraphFusion(&sub_graphs_);
793   for (Subgraph &sub : sub_graphs_) {
794     CheckSubHeadEnd(&sub);
795   }
796 
797   if (sub_graphs_.at(kDefaultFirstSubgraph).cost_.cost() < kMinSubgraphCost ||
798       sub_graphs_.at(kDefaultSecondSubgraph).cost_.cost() < kMinSubgraphCost) {
799     return;
800   }
801 
802   ConvertSubGraphToModel(&sub_graphs_);
803 }
804 
SubGraphSplitByMiddle()805 void SearchSubGraph::SubGraphSplitByMiddle() {
806   InitSearchSubGraphByMiddle();
807   for (auto map : node_sub_map_) {
808     std::vector<Subgraph> &subgraphs = map.second;
809     if (subgraphs.size() < kDefaultSubGraphSize) {
810       continue;
811     }
812 
813     CalculateCostModel(&subgraphs);
814     if (total_cost_ < kMinSubgraphCost) {
815       continue;
816     }
817 
818     InitSubgraphRuntimeInfo(&subgraphs);
819     SubgraphFusion(&subgraphs);
820 
821     MS_ASSERT(subgraphs.size() == kDefaultSubGraphSize);
822     if (subgraphs.at(kDefaultFirstSubgraph).nodes_.empty() || subgraphs.at(kDefaultSecondSubgraph).nodes_.empty()) {
823       continue;
824     }
825 
826     OptimizeAfterFusion(&subgraphs, map.first);
827 
828     /* redo cost-model and pre-set-info after optimize */
829     CalculateCostModel(&subgraphs);
830     if (subgraphs.at(kDefaultFirstSubgraph).cost_.cost() < kMinSubgraphCost ||
831         subgraphs.at(kDefaultSecondSubgraph).cost_.cost() < kMinSubgraphCost) {
832       continue;
833     }
834 
835     InitSubgraphRuntimeInfo(&subgraphs);
836 
837     InitMainGraphDevice(DT_CPU);
838 
839     ConvertSubGraphToModel(&subgraphs);
840   }
841 }
842 
SubGraphSplitByOffLineParallel()843 void SearchSubGraph::SubGraphSplitByOffLineParallel() {
844   sub_graphs_.clear();
845   node_list_ = model_->graph_.all_nodes_;
846 
847   std::vector<uint32_t> multy_in_nodes;
848 
849   SearchMultyInNodes(&multy_in_nodes);
850 
851   for (uint32_t node_index : multy_in_nodes) {
852     LiteGraph::Node *node = node_list_[node_index];
853     if (GetPrimitiveType(node->primitive_, SCHEMA_VERSION::SCHEMA_CUR) != schema::PrimitiveType_Concat) {
854       continue;
855     }
856     std::vector<Subgraph> node_subs;
857     for (uint32_t input_tensor_index : node->input_indices_) {
858       Tensor *tensor = &tensors_[input_tensor_index];
859       if (tensor->type_ == CONST) continue;
860       std::vector<uint32_t> input_nodes = tensor->out_nodes_;
861       Subgraph sub;
862       sub.ends_.push_back(input_nodes[0]);
863       InsertNodeByMid(input_nodes[0], &sub, input_nodes[0]);
864       node_subs.push_back(sub);
865     }
866     node_sub_map_.insert(std::make_pair(node_index, node_subs));
867   }
868 
869   for (auto map : node_sub_map_) {
870     std::vector<Subgraph> &subgraphs = map.second;
871 
872     if (std::any_of(subgraphs.begin(), subgraphs.end(), [&](Subgraph &sub) { return sub.nodes_.empty(); })) {
873       continue;
874     }
875 
876     if (!CheckIsParallelSubGraph(subgraphs)) {
877       continue;
878     }
879 
880     // init graph device type
881     for (auto &subgraph : subgraphs) {
882       uint32_t head_node_index = subgraph.heads_.front();
883       subgraph.device_ = static_cast<lite::DeviceType>(model_->graph_.all_nodes_.at(head_node_index)->device_type_);
884       if (subgraph.device_ == DT_GPU) {
885         subgraph.thread_ = major_thread_;
886         subgraph.tid_ = 0;
887       } else {
888         subgraph.thread_ = minor_thread_;
889         subgraph.tid_ = 1;
890       }
891     }
892     ConvertSubGraphToModel(&subgraphs);
893   }
894   InitMainGraphDevice(DT_CPU);
895 }
896 
SearchSubGraph(const InnerContext * context,Model * model,std::vector<lite::Tensor * > * src_tensors,const std::map<int,OpParameter * > * op_parameters,std::vector<size_t> * output_nodes)897 SearchSubGraph::SearchSubGraph(const InnerContext *context, Model *model, std::vector<lite::Tensor *> *src_tensors,
898                                const std::map<int, OpParameter *> *op_parameters, std::vector<size_t> *output_nodes)
899     : output_nodes_(output_nodes), context_(context), src_tensors_(src_tensors), op_parameters_(op_parameters) {
900   model_ = reinterpret_cast<LiteModel *>(model);
901 
902   major_dt_ = DT_CPU;
903   minor_dt_ = DT_CPU;
904   if (context_->IsNpuEnabled()) {
905     major_dt_ = DT_NPU;
906   } else if (context_->IsGpuEnabled()) {
907     major_dt_ = DT_GPU;
908   }
909 
910   if (major_dt_ == DT_GPU) {
911     major_thread_ = 1;
912     minor_thread_ = context_->thread_num_ - 1;
913   } else if (major_dt_ == DT_CPU) {
914     major_thread_ = UP_DIV(context_->thread_num_, kDefaultSubGraphSize);
915     minor_thread_ = static_cast<size_t>(context_->thread_num_ - major_thread_);
916   }
917   MS_ASSERT(major_thread_ > 0);
918   MS_ASSERT(minor_thread_ > 0);
919 
920   InitSearchTensor();
921   return;
922 }
923 
// Recursively walk the graph backwards from node `index`, gathering
// consecutive offline-parallel target nodes of one device type into
// `subgraph`; a finished subgraph is pushed into sub_graphs_ when the
// search terminates.
void SearchSubGraph::InsertParallelNode(uint32_t index, Subgraph *subgraph) {
  if (subgraph == nullptr) {
    return;
  }
  if (subgraph->search_terminate_) {
    // flush the collected subgraph before starting a fresh one
    if (!subgraph->nodes_.empty()) {
      sub_graphs_.push_back(std::move(*subgraph));
    }
    // NOTE(review): `subgraph` is re-pointed at a stack-local Subgraph. The
    // caller's pointer is unchanged and `new_graph` dies when this call
    // returns; the recursion below works because it runs inside this frame,
    // but nodes accumulated in `new_graph` are lost unless a deeper
    // terminate pushes them — confirm this is the intended behavior.
    Subgraph new_graph;
    subgraph = &new_graph;
  }
  LiteGraph::Node *node = node_list_[index];
  //  has been searched
  if (node == nullptr) {
    return;
  }

  // if current node is parallel target node
  if (IsOfflineParallelNode(node->primitive_, node->device_type_)) {
    // first searched
    if (subgraph->nodes_.empty()) {
      subgraph->device_ = static_cast<DeviceType>(node->device_type_);
    } else {
      // check pre_device_type equal to current device_type
      if (subgraph->device_ != static_cast<DeviceType>(node->device_type_)) {
        return;
      }
    }
    // prepend so nodes end up in topological (execution) order
    subgraph->nodes_.insert(subgraph->nodes_.begin(), index);
    // mark the node as visited for subsequent searches
    node_list_[index] = nullptr;
  } else {
    subgraph->search_terminate_ = true;
  }

  // just deal with parallel target node
  std::vector<uint32_t> input = node->input_indices_;

  /* remove const node */
  // iterate backwards so erasing does not skip elements
  for (int i = static_cast<int>(input.size()) - 1; i >= 0; i--) {
    if (tensors_[input[i]].type_ == CONST) {
      VectorErase(&input, input[i]);
    }
  }

  // search to graph to graph input , terminate it.
  if (std::any_of(input.begin(), input.end(),
                  [&](uint32_t input_index) { return tensors_[input_index].type_ == INPUT; })) {
    subgraph->search_terminate_ = true;
    return;
  }

  // search for next nodes
  for (uint32_t next : input) {
    auto next_nodes = tensors_[next].out_nodes_;
    for (uint32_t next_node : next_nodes) {
      InsertParallelNode(next_node, subgraph);
    }
  }
}
CheckSubHeadEnd(Subgraph * sub)983 void SearchSubGraph::CheckSubHeadEnd(Subgraph *sub) {
984   /* head-end node may error after subgraph fusion  */
985   /* sub head node check */
986   std::vector<uint32_t> delete_head;
987   for (uint32_t head_node : sub->heads_) {
988     if (std::find(sub->nodes_.begin(), sub->nodes_.end(), head_node) == sub->nodes_.end()) {
989       delete_head.push_back(head_node);
990       continue;
991     }
992     LiteGraph::Node *node = model_->graph_.all_nodes_.at(head_node);
993     std::vector<uint32_t> in_tensors = node->input_indices_;
994     std::vector<uint32_t> in_nodes;
995     for (uint32_t in_t : in_tensors) {
996       in_nodes.insert(in_nodes.begin(), tensors_.at(in_t).out_nodes_.begin(), tensors_.at(in_t).out_nodes_.end());
997     }
998 
999     if (in_nodes.empty()) {
1000       continue;
1001     }
1002 
1003     bool erase_head = true;
1004     for (uint32_t in_n : in_nodes) {
1005       if (std::find(sub->nodes_.begin(), sub->nodes_.end(), in_n) == sub->nodes_.end()) {
1006         erase_head = false;
1007         break;
1008       }
1009     }
1010     if (erase_head) {
1011       delete_head.push_back(head_node);
1012     }
1013   }
1014   for (uint32_t head : delete_head) {
1015     VectorErase(&sub->heads_, head);
1016   }
1017 
1018   /* sub end node check */
1019   std::vector<uint32_t> delete_end;
1020   for (uint32_t end_node : sub->ends_) {
1021     if (std::find(sub->nodes_.begin(), sub->nodes_.end(), end_node) == sub->nodes_.end()) {
1022       delete_end.push_back(end_node);
1023     }
1024   }
1025   for (uint32_t end : delete_end) {
1026     VectorErase(&sub->ends_, end);
1027   }
1028   return;
1029 }
1030 
ValidInParallel()1031 bool SearchSubGraph::ValidInParallel() {
1032   LiteGraph::Node *front_node = model_->graph_.all_nodes_.at(0);
1033   if (front_node->quant_type_ != schema::QuantType_QUANT_NONE) {
1034     return false;
1035   }
1036   if (major_dt_ == DT_NPU) {
1037     return false;
1038   }
1039   if (model_->graph_.sub_graphs_.size() > 1) {
1040     return false;
1041   }
1042   if (model_->GetSchemaVersion() != SCHEMA_VERSION::SCHEMA_CUR) {
1043     return false;
1044   }
1045   return true;
1046 }
1047 
SubGraphSplit()1048 void SearchSubGraph::SubGraphSplit() {
1049   if (!ValidInParallel()) {
1050     return;
1051   }
1052 
1053   UpdateOfflineParallelFlag();
1054   if (offline_parallel_enable_) {
1055     SubGraphSplitByOffLineParallel();
1056   } else {
1057     SubGraphSplitByOutput();
1058     SubGraphSplitByMiddle();
1059   }
1060   return;
1061 }
1062 }  // namespace mindspore::lite
1063