/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "src/litert/sub_graph_split.h"
#include <cstdlib>
#include <cstring>  // memcpy
#include <utility>
#include <algorithm>
#include <iterator>
#include <vector>
#include <queue>
#include <set>
#include "src/tensor.h"
#include "schema/ops_generated.h"
#include "schema/model_generated.h"
#include "src/common/ops/populate/populate_register.h"
#include "src/litert/scheduler.h"
#include "src/litert/tensor_category.h"
#include "nnacl/pooling_parameter.h"
#include "include/model.h"
#include "nnacl/base/conv_common_base.h"

namespace {
constexpr int kMaxDepth = 2048;
constexpr int kInitBufferSize = 1024;
constexpr int kOperatorMaxThreadNum = 16;
}  // namespace

namespace mindspore::lite {
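/* Rough cost model: a convolution is costed by its multiply count, the
 * output volume N*H*W*C times the kernel volume Kh*Kw*Cin.  Worked example
 * with hypothetical shapes: a 1x112x112x64 output and a 3x3(x64) kernel give
 * 112*112*64 * 3*3*64 ~= 4.6e8 multiplies. */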
size_t CommConvMul(std::vector<int> weight_shape, std::vector<int> output_shape) {
  size_t cost =
    static_cast<size_t>(output_shape[NHWC_N] * output_shape[NHWC_H] * output_shape[NHWC_W] * output_shape[NHWC_C] *
                        weight_shape[NHWC_H] * weight_shape[NHWC_W] * weight_shape[NHWC_C]);
  return cost;
}

size_t WinogradConvMul() {
  /* winograd conv */
  return 0;
}

size_t CommConvdwMul(std::vector<int> weight_shape, std::vector<int> output_shape) {
  size_t cost = static_cast<size_t>(output_shape[NHWC_N] * output_shape[NHWC_H] * output_shape[NHWC_W] *
                                    output_shape[NHWC_C] * weight_shape[NHWC_H] * weight_shape[NHWC_W]);
  return cost;
}

size_t WinogradConvDwMul() {
  /* winograd convdw */
  return 0;
}

bool IsOfflineParallelNode(const void *node_primitive, int node_device_type) {
  if (node_primitive == nullptr) {
    return false;
  }
  return (GetPrimitiveType(node_primitive, SCHEMA_VERSION::SCHEMA_CUR) == schema::PrimitiveType_Conv2DFusion) &&
         (node_device_type != kDefaultDeviceType);
}

void SearchSubGraph::UpdateOfflineParallelFlag() {
  if (model_ == nullptr) {
    offline_parallel_enable_ = false;
    return;
  }
  // visit all nodes to check whether any conv / depthwise-conv has been assigned a device type offline
  offline_parallel_enable_ = std::any_of(
    this->model_->graph_.all_nodes_.begin(), this->model_->graph_.all_nodes_.end(),
    [&](const lite::LiteGraph::Node *node) { return IsOfflineParallelNode(node->primitive_, node->device_type_); });
}

bool SearchSubGraph::CheckIsParallelSubGraph(const std::vector<Subgraph> &subgraphs) {
  if (subgraphs.size() != kDefaultSubGraphSize) {
    return false;
  }

  for (const auto &sub_graph : subgraphs) {
    auto heads = sub_graph.heads_;
    auto ends = sub_graph.ends_;
    if (heads.size() != kDefaultInputs || ends.size() != kDefaultInputs) {
      return false;
    }
    auto head_node = model_->graph_.all_nodes_.at(heads.front());
    auto end_node = model_->graph_.all_nodes_.at(ends.front());
    if (!IsOfflineParallelNode(head_node->primitive_, head_node->device_type_) ||
        !IsOfflineParallelNode(end_node->primitive_, end_node->device_type_)) {
      return false;
    }

    // 1. check that every non-constant input of the head node is produced by a SplitWithOverlap node
    for (const auto &input : head_node->input_indices_) {
      if (tensors_.at(input).type_ == CONSTANT) {
        continue;
      }
      auto input_node_index = tensors_.at(input).out_nodes_.front();
      if (GetPrimitiveType(model_->graph_.all_nodes_.at(input_node_index)->primitive_, SCHEMA_VERSION::SCHEMA_CUR) !=
          schema::PrimitiveType_SplitWithOverlap) {
        return false;
      }
    }

    // 2. check that every non-constant output of the end node feeds a Concat node
    for (const auto &output : end_node->output_indices_) {
      if (tensors_.at(output).type_ == CONSTANT) {
        continue;
      }
      auto output_node_index = tensors_.at(output).in_nodes_.front();
      if (GetPrimitiveType(model_->graph_.all_nodes_.at(output_node_index)->primitive_, SCHEMA_VERSION::SCHEMA_CUR) !=
          schema::PrimitiveType_Concat) {
        return false;
      }
    }
  }
  return true;
}

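/* Balanced bipartition by exhaustive search: enumerate every assignment of
 * the n subgraphs to the major device (tmp_group[i] == true) and record in
 * cor_group the assignment whose summed cost is closest to except_value.
 * E.g. for costs {3, 5, 8, 9} and except_value 12, the subset {3, 9} wins
 * with distance 0.  The search is exponential in n; kMaxDepth only bounds
 * the recursion depth. */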
void SearchSubGraph::dfs(int i, int n, int current_sum, int except_value, int *min_value, std::vector<bool> *tmp_group,
                         std::vector<bool> *cor_group, std::vector<Subgraph> *sub_graphs) {
  if (i > kMaxDepth) {
    return;
  }
  if (i == n) {
    if (abs(except_value - current_sum) < *min_value) {
      for (int j = 0; j < n; j++) {
        cor_group->at(j) = tmp_group->at(j);
      }
    }
    *min_value = MSMIN(*min_value, abs(except_value - current_sum));
    return;
  }

  {
    tmp_group->at(i) = true;
    int next_sum = current_sum + sub_graphs->at(i).cost_.cost();
    dfs(i + 1, n, next_sum, except_value, min_value, tmp_group, cor_group, sub_graphs);
  }

  {
    tmp_group->at(i) = false;
    dfs(i + 1, n, current_sum, except_value, min_value, tmp_group, cor_group, sub_graphs);
  }
  return;
}

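/* Cost a Conv2DFusion node.  Plain convolutions (group == 1) are costed with
 * CommConvMul whether or not Winograd applies (the Winograd*Mul helpers above
 * are placeholders returning 0 and are not called); depthwise convolutions
 * use CommConvdwMul; other grouped convolutions are left uncosted. */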
SearchSubGraph::CostModel SearchSubGraph::CalculateConv2DFusion(const LiteGraph::Node *node) {
  CostModel cost;
  std::vector<uint32_t> inputs = node->input_indices_;
  std::vector<uint32_t> outputs = node->output_indices_;

  std::vector<int> weight_shape = src_tensors_->at(inputs[1])->shape();
  std::vector<int> output_shape = src_tensors_->at(outputs[0])->shape();

  ConvParameter *param = reinterpret_cast<ConvParameter *>(op_parameters_->at(outputs[0]));

  if (param->group_ == 1) {
    if (param->kernel_h_ == 1 && param->kernel_w_ == 1) {
      size_t conv1x1_mul_cost = CommConvMul(weight_shape, output_shape);
      cost.mul_cost_ += conv1x1_mul_cost;
    } else {
      int out_unit;
      if (CheckIfUseWinograd(&out_unit, param)) {
        size_t winograd_conv_cost = CommConvMul(weight_shape, output_shape);
        cost.mul_cost_ += winograd_conv_cost;
      } else {
        size_t comm_conv_mul_cost = CommConvMul(weight_shape, output_shape);
        cost.mul_cost_ += comm_conv_mul_cost;
      }
    }
  } else if (param->group_ == param->input_channel_ && param->group_ == param->output_channel_) {
#if defined(ENABLE_ARM) || (defined(ENABLE_SSE) && !defined(ENABLE_AVX))
    if (CheckConvDw1DWinograd(param, context_->thread_num_)) {
      /* ConvolutionDepthwise3x3CPUKernel */
      size_t winograd_convdw_cost = CommConvdwMul(weight_shape, output_shape);
      cost.mul_cost_ += winograd_convdw_cost;
    } else {
      /* ConvolutionDepthwiseIndirectCPUKernel */
      /* ConvolutionDepthwiseSWCPUKernel */
      /* ConvolutionDepthwiseCPUKernel */
      size_t comm_convdw_cost = CommConvdwMul(weight_shape, output_shape);
      cost.mul_cost_ += comm_convdw_cost;
    }
#else
    size_t comm_convdw_cost = CommConvdwMul(weight_shape, output_shape);
    cost.mul_cost_ += comm_convdw_cost;
#endif
  } else {
    /* group conv */
  }
  return cost;
}

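/* Serialize a PartialFusion primitive that calls subgraph_index.  The
 * FlatBufferBuilder owns its buffer, so the finished bytes are copied into a
 * malloc'd block; ownership of that block passes to model_->node_bufs_,
 * which is presumably freed when the model is destroyed. */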
const schema::Primitive *SearchSubGraph::CreatePartialPrimitive(int64_t subgraph_index) {
  flatbuffers::FlatBufferBuilder fbb(kInitBufferSize);
  auto val_offset = schema::CreatePartialFusion(fbb, subgraph_index);
  auto prim_offset = schema::CreatePrimitive(fbb, schema::PrimitiveType_PartialFusion, val_offset.o);
  fbb.Finish(prim_offset);
  auto tmp_buf = fbb.GetBufferPointer();
  void *prim_buf = malloc(fbb.GetSize());
  if (prim_buf == nullptr) {
    return nullptr;
  }
  memcpy(prim_buf, tmp_buf, fbb.GetSize());

  auto primitive = flatbuffers::GetRoot<schema::Primitive>(prim_buf);
  fbb.Clear();

  model_->node_bufs_.push_back(prim_buf);
  return primitive;
}

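/* Move each non-empty Subgraph out of the main graph: its nodes become a new
 * LiteGraph::SubGraph, and a single partial-call node invoking that subgraph
 * is spliced into the main graph at the position of the earliest removed
 * node, so the main graph's topological order is preserved. */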
void SearchSubGraph::ConvertSubGraphToModel(std::vector<Subgraph> *sub_graphs) {
  LiteGraph::SubGraph *main_graphs = model_->graph_.sub_graphs_.front();

  for (Subgraph &subgraph : *sub_graphs) {
    if (subgraph.nodes_.empty()) {
      continue;
    }

    DeviceType device_type = subgraph.device_;
    size_t thread_num = subgraph.thread_;
    int new_sub_index = static_cast<int>(model_->graph_.sub_graphs_.size());
    int partial_index = static_cast<int>(model_->graph_.all_nodes_.size());
    int particial_replace_index = partial_index;

    LiteGraph::SubGraph *new_sub_graph = new (std::nothrow) LiteGraph::SubGraph();
    if (new_sub_graph == nullptr) {
      MS_LOG(ERROR) << "New sub graph failed!";
      return;
    }
    new_sub_graph->name_ = "SubSplit" + std::to_string(new_sub_index);
    LiteGraph::Node *new_partial_node = new (std::nothrow) LiteGraph::Node();
    if (new_partial_node == nullptr) {
      MS_LOG(ERROR) << "New partial node failed!";
      delete new_sub_graph;
      return;
    }
    new_partial_node->name_ = "SubSplitPartial" + std::to_string(new_sub_index);
    if (device_type == DT_CPU) {
      new_partial_node->name_ = "Cpu" + new_partial_node->name_;
    } else if (device_type == DT_GPU) {
      new_partial_node->name_ = "Gpu" + new_partial_node->name_;
    } else if (device_type == DT_NPU) {
      new_partial_node->name_ = "Npu" + new_partial_node->name_;
    }

    new_partial_node->node_type_ = static_cast<int>(mindspore::lite::NodeType_ValueNode);
    new_partial_node->primitive_ = CreatePartialPrimitive(new_sub_index);

    while (!subgraph.nodes_.empty()) {
      uint32_t node_index = subgraph.nodes_.front();
      LiteGraph::Node *cur_node = model_->graph_.all_nodes_[node_index];
      new_sub_graph->node_indices_.push_back(node_index);

      auto iter = find(main_graphs->node_indices_.begin(), main_graphs->node_indices_.end(), node_index);
      int cur_node_index = std::distance(std::begin(main_graphs->node_indices_), iter);
      particial_replace_index = (cur_node_index < particial_replace_index) ? cur_node_index : particial_replace_index;

      VectorErase(&main_graphs->node_indices_, node_index);
      VectorErase(&subgraph.nodes_, node_index);
      cur_node->device_type_ = static_cast<int>(device_type);
      op_parameters_->at(cur_node->output_indices_.at(0))->thread_num_ = static_cast<int>(thread_num);
    }

    for (uint32_t head_index : subgraph.heads_) {
      LiteGraph::Node *head_node = model_->graph_.all_nodes_[head_index];
      std::vector<uint32_t> inputs = head_node->input_indices_;
      for (auto input : inputs) {
        if (tensors_[input].type_ == CONSTANT) {
          continue;
        }
        if (std::find(new_sub_graph->input_indices_.begin(), new_sub_graph->input_indices_.end(), input) !=
            new_sub_graph->input_indices_.end()) {
          continue;
        }

        auto in_tensor_in_nodes = tensors_[input].out_nodes_;
        if (!in_tensor_in_nodes.empty()) {
          uint32_t in_tensor_in_node = in_tensor_in_nodes[0];
          if (std::find(new_sub_graph->node_indices_.begin(), new_sub_graph->node_indices_.end(), in_tensor_in_node) !=
              new_sub_graph->node_indices_.end()) {
            continue;
          }
        }

        new_sub_graph->input_indices_.insert(new_sub_graph->input_indices_.end(), input);
        new_partial_node->input_indices_.insert(new_partial_node->input_indices_.end(), input);
      }
    }

    for (uint32_t end_index : subgraph.ends_) {
      LiteGraph::Node *end_node = model_->graph_.all_nodes_[end_index];
      std::vector<uint32_t> outputs = end_node->output_indices_;
      new_sub_graph->output_indices_.insert(new_sub_graph->output_indices_.end(), outputs.begin(), outputs.end());
      new_partial_node->output_indices_.insert(new_partial_node->output_indices_.end(), outputs.begin(), outputs.end());
    }

    main_graphs->node_indices_.insert(main_graphs->node_indices_.begin() + particial_replace_index, partial_index);
    model_->graph_.all_nodes_.push_back(new_partial_node);
    model_->graph_.sub_graphs_.push_back(new_sub_graph);
  }

  sub_graphs->clear();
  return;
}

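/* A node is a subgraph "head" if some consumer of its outputs is not in
 * ready_nodes, i.e. its result escapes the subgraph collected so far.  The
 * single-output, single-consumer early exit keeps straight chains from being
 * flagged. */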
bool SearchSubGraph::IsNodeSubGraphHead(uint32_t node_index, const std::vector<uint32_t> &ready_nodes) {
  std::vector<uint32_t> output_indexes = model_->graph_.all_nodes_.at(node_index)->output_indices_;
  std::vector<uint32_t> output_nodes;
  for (uint32_t out_t : output_indexes) {
    std::vector<uint32_t> cur_nodes = tensors_[out_t].in_nodes_;
    output_nodes.insert(output_nodes.end(), cur_nodes.begin(), cur_nodes.end());
  }
  if (output_indexes.size() == 1 && output_nodes.size() == 1) {
    return false;
  }
  for (uint32_t out_n : output_nodes) {
    if (find(ready_nodes.begin(), ready_nodes.end(), out_n) == ready_nodes.end()) {
      return true;
    }
  }
  return false;
}

bool SearchSubGraph::IsNodeSubGraphHeadWithRoot(uint32_t node_index, const std::vector<uint32_t> &ready_nodes,
                                                uint32_t root_node_index) {
  std::vector<uint32_t> output_indexes = node_list_.at(node_index)->output_indices_;
  std::vector<uint32_t> output_nodes;
  for (uint32_t out_t : output_indexes) {
    std::vector<uint32_t> cur_nodes = tensors_[out_t].in_nodes_;
    output_nodes.insert(output_nodes.end(), cur_nodes.begin(), cur_nodes.end());
  }
  for (uint32_t out_n : output_nodes) {
    if (root_node_index != out_n) {
      if (find(ready_nodes.begin(), ready_nodes.end(), out_n) == ready_nodes.end()) {
        return true;
      }
    }
  }
  return false;
}

void SearchSubGraph::SearchMultyInNodes(std::vector<uint32_t> *multy_in_nodes) {
  std::vector<uint32_t> all_main_sub_nodes = model_->graph_.sub_graphs_[0]->node_indices_;
  for (size_t i = 0; i < all_main_sub_nodes.size(); i++) {
    uint32_t node_index = all_main_sub_nodes[i];
    LiteGraph::Node *node = node_list_[node_index];

    if (IsPartialNode(node->primitive_, model_->GetSchemaVersion())) {
      continue;
    }
    int input_count = static_cast<int>(
      std::count_if(node->input_indices_.begin(), node->input_indices_.end(),
                    [&](uint32_t in_tensor_index) { return tensors_[in_tensor_index].type_ != CONSTANT; }));
    if (input_count > 1) {
      multy_in_nodes->push_back(node_index);
    }
  }
  return;
}

void SearchSubGraph::RemoveConstNode(std::vector<uint32_t> *nodes) {
  /* despite the name, "nodes" holds tensor indices; drop every CONSTANT tensor */
  bool stop_search = false;
  while (!stop_search) {
    stop_search = true;
    for (size_t i = 0; i < nodes->size(); i++) {
      if (tensors_[nodes->at(i)].type_ == CONSTANT) {
        VectorErase(nodes, nodes->at(i));
        stop_search = false;
        break;
      }
    }
  }

  return;
}

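/* Reverse depth-first search: starting from an output node, pull producers
 * into the subgraph until a graph input or a node shared with another
 * consumer is reached.  Nodes are prepended to nodes_, keeping the list in
 * roughly topological order. */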
void SearchSubGraph::InsertNode(uint32_t index, Subgraph *subgraph, uint32_t last_index) {
  if (subgraph->search_terminate_) {
    return;
  }

  LiteGraph::Node *node = node_list_.at(index);
  if (node == nullptr) {
    return;
  }

  std::vector<uint32_t> input = node->input_indices_;
  RemoveConstNode(&input);

  /* all node inputs are graph inputs: nothing left to split */
  for (size_t i = 0; i < input.size(); i++) {
    if (tensors_[input[i]].type_ != INPUT) {
      break;
    }
    subgraph->heads_.clear();
    subgraph->ends_.clear();
    subgraph->nodes_.clear();
    subgraph->search_terminate_ = true;
    return;
  }

  /* split in graph */
  if (IsNodeSubGraphHead(index, subgraph->nodes_)) {
    if (subgraph->nodes_.empty()) {
      subgraph->search_terminate_ = true;
      return;
    }
    subgraph->heads_.push_back(last_index);
    return;
  }

  if (find(output_nodes_->begin(), output_nodes_->end(), index) != output_nodes_->end()) {
    subgraph->ends_.push_back(index);
  }

  /* insert node into the current subgraph */
  subgraph->nodes_.insert(subgraph->nodes_.begin(), index);
  node_list_.at(index) = nullptr;

  /* search for next node */
  for (uint32_t in : input) {
    auto next_nodes = tensors_[in].out_nodes_;
    for (uint32_t next_node : next_nodes) {
      InsertNode(next_node, subgraph, index);
    }
  }
  return;
}

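/* Post-fusion cleanup: try to grow each subgraph upward through its heads.
 * A head's sole producer is absorbed when it has exactly one input tensor
 * and none of its consumers (other than the root multi-input node) lie
 * outside the subgraph; absorbed heads are dropped afterwards. */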
void SearchSubGraph::OptimizeAfterFusion(std::vector<Subgraph> *sub_graphs, uint32_t root_node_index) {
  MS_ASSERT(sub_graphs->size() == kDefaultSubGraphSize);
  for (Subgraph &sub : *sub_graphs) {
    if (sub.nodes_.empty()) {
      return;
    }
    int head_size = static_cast<int>(sub.heads_.size());
    std::vector<uint32_t> used_heads;
    for (int i = 0; i < head_size; i++) {
      uint32_t head_node_index = sub.heads_.at(i);
      if (std::find(used_heads.begin(), used_heads.end(), head_node_index) != used_heads.end()) {
        break;
      }
      std::vector<uint32_t> head_input_tensors = model_->graph_.all_nodes_[head_node_index]->input_indices_;
      RemoveConstNode(&head_input_tensors);
      if (head_input_tensors.size() != 1) continue;

      std::vector<uint32_t> input_nodes = tensors_.at(head_input_tensors.at(0)).out_nodes_;
      if (input_nodes.size() != 1) continue;
      uint32_t input_node_index = input_nodes.at(0);

      std::vector<uint32_t> input_tensors = model_->graph_.all_nodes_[input_node_index]->input_indices_;
      RemoveConstNode(&input_tensors);
      if (input_tensors.size() != 1) continue;

      /* this node qualifies:
       * 1. it is the only input node of the current head node
       * 2. all of its outputs stay inside the current subgraph
       * 3. it has exactly one input tensor */
      if (!IsNodeSubGraphHeadWithRoot(input_node_index, sub.nodes_, root_node_index)) {
        InsertHeadNode(input_node_index, &sub);
        used_heads.push_back(head_node_index); /* delete used heads at the end */
      }
      head_size = static_cast<int>(sub.heads_.size());
    }
    for (auto head_index : used_heads) {
      VectorErase(&sub.heads_, head_index);
    }

    CheckSubHeadEnd(&sub);

    /* sort node indices */
    std::sort(sub.nodes_.begin(), sub.nodes_.end());
  }
}

void SearchSubGraph::InsertHeadNode(uint32_t head_node_index, Subgraph *subgraph) {
  LiteGraph::Node *node = node_list_.at(head_node_index);
  std::vector<uint32_t> head_node_inputs = node->input_indices_;
  RemoveConstNode(&head_node_inputs);

  subgraph->nodes_.push_back(head_node_index);
  node_list_.at(head_node_index) = nullptr;

  /* search for next node */
  size_t current_node_size = subgraph->nodes_.size();
  for (uint32_t in : head_node_inputs) {
    auto next_nodes = tensors_[in].out_nodes_;
    for (uint32_t next_node : next_nodes) {
      InsertNodeByMid(next_node, subgraph, head_node_index);
    }
  }

  if (current_node_size == subgraph->nodes_.size()) {
    subgraph->heads_.push_back(head_node_index);
  }
  return;
}

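/* Variant of InsertNode that understands cached multi-input units: if the
 * node was already explored as a multi-input anchor (node_sub_map_), the
 * whole cached unit is merged into the current subgraph and the search
 * continues from its heads; otherwise the walk follows producers node by
 * node. */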
void SearchSubGraph::InsertNodeByMid(uint32_t node_index, Subgraph *subgraph, uint32_t last_index) {
  LiteGraph::Node *node = node_list_.at(node_index);
  MS_CHECK_PTR_IF_NULL(node);

  auto subs_iter = node_sub_map_.find(node_index);
  if (subs_iter != node_sub_map_.end()) {
    /* node is a multi-input node, already searched before */

    if (IsNodeSubGraphHead(node_index, subgraph->nodes_)) {
      /* this node cannot be included in this subgraph */
      if (!subgraph->nodes_.empty()) subgraph->heads_.push_back(last_index);
      return;
    }

    /* include this multi-input unit in the current subgraph */
    std::vector<Subgraph> &subs = subs_iter->second;

    /* insert nodes */
    subgraph->nodes_.push_back(node_index);
    for (Subgraph &sub : subs) {
      subgraph->nodes_.insert(subgraph->nodes_.end(), sub.nodes_.begin(), sub.nodes_.end());
    }

    /* insert heads */
    std::set<uint32_t> subs_head;
    subs_head.insert(node_index);
    for (Subgraph &sub : subs) {
      for (uint32_t head : sub.heads_) {
        subs_head.insert(head);
      }
    }

    std::set<uint32_t> subs_head_baklist = subs_head;
    for (uint32_t head_node : subs_head) {
      std::vector<uint32_t> head_input_tensors = model_->graph_.all_nodes_[head_node]->input_indices_;
      RemoveConstNode(&head_input_tensors);
      if (head_input_tensors.size() != 1) continue;
      std::vector<uint32_t> input_nodes = tensors_.at(head_input_tensors.at(0)).out_nodes_;
      if (input_nodes.size() != 1) continue;

      uint32_t input_node = input_nodes.at(0);
      if (!IsNodeSubGraphHead(input_node, subgraph->nodes_)) {
        InsertNodeByMid(input_node, subgraph, head_node);
        subs_head_baklist.erase(head_node);
      }
    }

    /* stop search */
    for (auto head : subs_head_baklist) {
      subgraph->heads_.push_back(head);
    }
    node_sub_map_.erase(node_index);
    return;
  }

  std::vector<uint32_t> inputs = node->input_indices_;
  RemoveConstNode(&inputs);

  if (IsNodeSubGraphHead(node_index, subgraph->nodes_)) {
    if (!subgraph->nodes_.empty()) {
      if (std::find(subgraph->heads_.begin(), subgraph->heads_.end(), last_index) == subgraph->heads_.end()) {
        subgraph->heads_.push_back(last_index);
      }
    }
    return;
  }

  subgraph->nodes_.insert(subgraph->nodes_.begin(), node_index);
  node_list_.at(node_index) = nullptr;

  /* search for next node */
  for (uint32_t in : inputs) {
    auto next_nodes = tensors_[in].out_nodes_;
    if (next_nodes.size() == 0) {
      if (!subgraph->nodes_.empty()) subgraph->heads_.push_back(subgraph->nodes_.front());
    } else {
      for (uint32_t next_node : next_nodes) {
        InsertNodeByMid(next_node, subgraph, node_index);
      }
    }
  }
  return;
}

void SearchSubGraph::InitMiddleSubgraph(const std::vector<uint32_t> *multy_in_nodes) {
  for (uint32_t node_index : *multy_in_nodes) {
    std::vector<Subgraph> node_subs;
    LiteGraph::Node *node = node_list_[node_index];
    for (uint32_t input_tensor_index : node->input_indices_) {
      Tensor *tensor = &tensors_[input_tensor_index];
      if (tensor->type_ == CONSTANT || tensor->type_ == INPUT) continue;

      std::vector<uint32_t> input_nodes = tensor->out_nodes_;
      if (input_nodes.size() != 1) continue;
      uint32_t input_node = input_nodes[0];

      Subgraph sub;
      sub.ends_.push_back(input_node);
      InsertNodeByMid(input_node, &sub, input_node);
      node_subs.push_back(sub);
    }
    if (!node_subs.empty()) {
      node_sub_map_.insert(std::make_pair(node_index, node_subs));
    }
  }
  return;
}

void SearchSubGraph::InitSearchSubGraphByMiddle() {
  sub_graphs_.clear();
  node_list_ = model_->graph_.all_nodes_;

  std::vector<uint32_t> multy_in_nodes;

  SearchMultyInNodes(&multy_in_nodes);

  if (multy_in_nodes.size() > kMaxMultyInNode) {
    node_sub_map_.clear();
    return;
  }

  InitMiddleSubgraph(&multy_in_nodes);

  if (node_sub_map_.size() > kMaxSubGraphCount) {
    node_sub_map_.clear();
  }
  return;
}

void SearchSubGraph::InitSearchSubGraphByOutput() {
  sub_graphs_.clear();
  node_list_ = model_->graph_.all_nodes_;

  for (auto out : *output_nodes_) {
    Subgraph subgraph;

    InsertNode(static_cast<uint32_t>(out), &subgraph, static_cast<uint32_t>(out));

    sub_graphs_.push_back(std::move(subgraph));
  }
  return;
}

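/* Build the tensor-level adjacency every search above relies on: each tensor
 * is classified CONSTANT / INPUT / NORMAL, and for every tensor in_nodes_
 * records its consumers while out_nodes_ records its producers. */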
void SearchSubGraph::InitSearchTensor() {
  tensors_.resize(model_->graph_.all_tensors_.size());

  /* Set Tensor Type */
  for (size_t i = 0; i < tensors_.size(); i++) {
    tensors_[i].type_ = NORMAL;
    mindspore::schema::Tensor *src_tensor = model_->graph_.all_tensors_[i];
    if (src_tensor == nullptr) {
      continue;
    }
    auto category = TensorCategory(*src_tensor);
    if (category == mindspore::lite::Category::CONST_TENSOR || category == mindspore::lite::Category::CONST_SCALAR) {
      tensors_[i].type_ = CONSTANT;
    }
  }
  std::vector<uint32_t> graph_input = model_->graph_.sub_graphs_[0]->input_indices_;
  for (auto in : graph_input) {
    tensors_[in].type_ = INPUT;
  }

  /* record each tensor's consumer (in_nodes_) and producer (out_nodes_) nodes */
  for (size_t index = 0; index < model_->graph_.all_nodes_.size(); index++) {
    LiteGraph::Node *node = model_->graph_.all_nodes_[index];
    std::vector<uint32_t> input = node->input_indices_;
    for (uint32_t in : input) {
      tensors_[in].in_nodes_.push_back(index);
    }
    std::vector<uint32_t> output = node->output_indices_;
    for (uint32_t out : output) {
      tensors_[out].out_nodes_.push_back(index);
    }
  }
  return;
}

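/* Two-way load balancing: dfs() selects the subset of subgraphs whose summed
 * cost is closest to except_value (total_cost_ scaled by kDefaultGpu,
 * apparently 0.5 per the comment below); that subset runs on the major
 * device.  If the selected half turns out lighter than except_value, the
 * assignment is inverted so the major device always gets the bigger share. */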
void SearchSubGraph::InitSubgraphRuntimeInfo(std::vector<Subgraph> *sub_graphs) {
  std::vector<bool> tmp_group;
  std::vector<bool> cor_group;

  tmp_group.resize(sub_graphs->size());
  cor_group.resize(sub_graphs->size());

  int except_value = static_cast<int>(total_cost_ * kDefaultGpu); /* major device responsible for ~50% of the calculation */
  int min_value = INT32_MAX;

  dfs(0, static_cast<int>(sub_graphs->size()), 0, except_value, &min_value, &tmp_group, &cor_group, sub_graphs);

  /* give the bigger half to major_dt_ */
  int true_value = 0;
  for (size_t i = 0; i < sub_graphs->size(); i++) {
    if (cor_group.at(i)) {
      true_value += sub_graphs->at(i).cost_.cost();
    }
  }

  if (true_value < except_value) {
    (void)std::transform(cor_group.begin(), cor_group.end(), cor_group.begin(), [](bool value) { return !value; });
  }

  for (size_t i = 0; i < sub_graphs->size(); i++) {
    if (cor_group.at(i)) {
      sub_graphs->at(i).device_ = major_dt_;
      sub_graphs->at(i).thread_ = major_thread_;
      sub_graphs->at(i).tid_ = 0;
    } else {
      sub_graphs->at(i).device_ = minor_dt_;
      sub_graphs->at(i).thread_ = minor_thread_;
      sub_graphs->at(i).tid_ = 1;
    }
  }
}

void SearchSubGraph::InitMainGraphDevice(DeviceType dt) {
  LiteGraph::SubGraph *main_graph = model_->graph_.sub_graphs_.front();
  for (uint32_t node_index : main_graph->node_indices_) {
    LiteGraph::Node *node = model_->graph_.all_nodes_[node_index];
    node->device_type_ = dt;
  }
}

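/* Merge subgraphs pairwise until only kDefaultSubGraphSize remain: the first
 * pair sharing a tid_ (device assignment) is fused, concatenating nodes,
 * heads and ends, and summing costs. */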
void SearchSubGraph::SubgraphFusion(std::vector<Subgraph> *sub_graphs) {
  while (sub_graphs->size() > kDefaultSubGraphSize) {
    size_t sub1_index = 0;
    size_t sub2_index = 0;
    bool is_found = false;
    for (sub1_index = 0; sub1_index < sub_graphs->size(); sub1_index++) {
      for (size_t tmp2 = sub1_index + 1; tmp2 < sub_graphs->size(); tmp2++) {
        if (sub_graphs->at(sub1_index).tid_ == sub_graphs->at(tmp2).tid_) {
          sub2_index = tmp2;
          is_found = true;
          break;
        }
      }
      if (is_found) {
        break;
      }
    }
    MS_ASSERT(sub2_index > sub1_index); /* erase sub2 then sub1 */

    Subgraph new_sub;
    new_sub.device_ = sub_graphs->at(sub1_index).device_;
    new_sub.thread_ = sub_graphs->at(sub1_index).thread_;
    new_sub.tid_ = sub_graphs->at(sub1_index).tid_;
    new_sub.cost_ = sub_graphs->at(sub1_index).cost_ + sub_graphs->at(sub2_index).cost_;

    Subgraph &sub1 = sub_graphs->at(sub1_index);
    Subgraph &sub2 = sub_graphs->at(sub2_index);
    new_sub.nodes_.insert(new_sub.nodes_.end(), sub1.nodes_.begin(), sub1.nodes_.end());
    new_sub.nodes_.insert(new_sub.nodes_.end(), sub2.nodes_.begin(), sub2.nodes_.end());
    new_sub.heads_.insert(new_sub.heads_.end(), sub1.heads_.begin(), sub1.heads_.end());
    new_sub.heads_.insert(new_sub.heads_.end(), sub2.heads_.begin(), sub2.heads_.end());
    new_sub.ends_.insert(new_sub.ends_.end(), sub1.ends_.begin(), sub1.ends_.end());
    new_sub.ends_.insert(new_sub.ends_.end(), sub2.ends_.begin(), sub2.ends_.end());
    sub_graphs->erase(sub_graphs->begin() + sub2_index);
    sub_graphs->erase(sub_graphs->begin() + sub1_index);
    sub_graphs->insert(sub_graphs->end(), std::move(new_sub));
  }

  return;
}

void SearchSubGraph::CalculateCostModel(std::vector<Subgraph> *sub_graphs) {
  total_cost_ = 0;
  for (Subgraph &subgraph : *sub_graphs) {
    subgraph.cost_.empty(); /* empty() is expected to reset the accumulated cost */
    std::vector<uint32_t> nodes = subgraph.nodes_;
    for (uint32_t node_index : nodes) {
      CostModel cost;
      cost.io_cost_ = 0;
      cost.mul_cost_ = 1;

      LiteGraph::Node *node = model_->graph_.all_nodes_[node_index];
      if (GetPrimitiveType(node->primitive_, SCHEMA_VERSION::SCHEMA_CUR) == schema::PrimitiveType_Conv2DFusion) {
        cost = CalculateConv2DFusion(node);
      }

      subgraph.cost_ = subgraph.cost_ + cost;
      total_cost_ += static_cast<size_t>(cost.cost());
    }
  }
}

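/* Split strategy 1: seed one candidate subgraph per model output, balance and
 * fuse down to two, and rewrite the model only if both halves clear
 * kMinSubgraphCost. */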
void SearchSubGraph::SubGraphSplitByOutput() {
  if (output_nodes_->size() < kDefaultSubGraphSize) {
    return;
  }

  InitSearchSubGraphByOutput();
  CalculateCostModel(&sub_graphs_);
  InitSubgraphRuntimeInfo(&sub_graphs_);
  SubgraphFusion(&sub_graphs_);
  for (Subgraph &sub : sub_graphs_) {
    CheckSubHeadEnd(&sub);
  }

  if (sub_graphs_.at(kDefaultFirstSubgraph).cost_.cost() < kMinSubgraphCost ||
      sub_graphs_.at(kDefaultSecondSubgraph).cost_.cost() < kMinSubgraphCost) {
    return;
  }

  ConvertSubGraphToModel(&sub_graphs_);
}

void SearchSubGraph::SubGraphSplitByMiddle() {
  InitSearchSubGraphByMiddle();
  for (auto map : node_sub_map_) {
    std::vector<Subgraph> &subgraphs = map.second;
    if (subgraphs.size() < kDefaultSubGraphSize) {
      continue;
    }

    CalculateCostModel(&subgraphs);
    if (total_cost_ < kMinSubgraphCost) {
      continue;
    }

    InitSubgraphRuntimeInfo(&subgraphs);
    SubgraphFusion(&subgraphs);

    MS_ASSERT(subgraphs.size() == kDefaultSubGraphSize);
    if (subgraphs.at(kDefaultFirstSubgraph).nodes_.empty() || subgraphs.at(kDefaultSecondSubgraph).nodes_.empty()) {
      continue;
    }

    OptimizeAfterFusion(&subgraphs, map.first);

    /* redo the cost model and runtime info after optimization */
    CalculateCostModel(&subgraphs);
    if (subgraphs.at(kDefaultFirstSubgraph).cost_.cost() < kMinSubgraphCost ||
        subgraphs.at(kDefaultSecondSubgraph).cost_.cost() < kMinSubgraphCost) {
      continue;
    }

    InitSubgraphRuntimeInfo(&subgraphs);

    InitMainGraphDevice(DT_CPU);

    ConvertSubGraphToModel(&subgraphs);
  }
}

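/* Offline-parallel split: rebuild the two parallel branches that were
 * presumably annotated ahead of time (SplitWithOverlap ... Conv2DFusion ...
 * Concat with per-node device types), anchoring the search on Concat nodes. */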
void SearchSubGraph::SubGraphSplitByOffLineParallel() {
  sub_graphs_.clear();
  node_list_ = model_->graph_.all_nodes_;

  std::vector<uint32_t> multy_in_nodes;

  SearchMultyInNodes(&multy_in_nodes);

  for (uint32_t node_index : multy_in_nodes) {
    LiteGraph::Node *node = node_list_[node_index];
    if (GetPrimitiveType(node->primitive_, SCHEMA_VERSION::SCHEMA_CUR) != schema::PrimitiveType_Concat) {
      continue;
    }
    std::vector<Subgraph> node_subs;
    for (uint32_t input_tensor_index : node->input_indices_) {
      Tensor *tensor = &tensors_[input_tensor_index];
      if (tensor->type_ == CONSTANT) continue;
      std::vector<uint32_t> input_nodes = tensor->out_nodes_;
      Subgraph sub;
      sub.ends_.push_back(input_nodes[0]);
      InsertNodeByMid(input_nodes[0], &sub, input_nodes[0]);
      node_subs.push_back(sub);
    }
    node_sub_map_.insert(std::make_pair(node_index, node_subs));
  }

  for (auto map : node_sub_map_) {
    std::vector<Subgraph> &subgraphs = map.second;

    if (std::any_of(subgraphs.begin(), subgraphs.end(), [&](Subgraph &sub) { return sub.nodes_.empty(); })) {
      continue;
    }

    if (!CheckIsParallelSubGraph(subgraphs)) {
      continue;
    }

    // init subgraph device type
    for (auto &subgraph : subgraphs) {
      uint32_t head_node_index = subgraph.heads_.front();
      subgraph.device_ = static_cast<lite::DeviceType>(model_->graph_.all_nodes_.at(head_node_index)->device_type_);
      if (subgraph.device_ == DT_GPU) {
        subgraph.thread_ = major_thread_;
        subgraph.tid_ = 0;
      } else {
        subgraph.thread_ = minor_thread_;
        subgraph.tid_ = 1;
      }
    }
    ConvertSubGraphToModel(&subgraphs);
  }
  InitMainGraphDevice(DT_CPU);
}

SearchSubGraph::SearchSubGraph(const InnerContext *context, Model *model, std::vector<lite::Tensor *> *src_tensors,
                               const std::map<int, OpParameter *> *op_parameters, std::vector<size_t> *output_nodes)
    : context_(context), src_tensors_(src_tensors), output_nodes_(output_nodes), op_parameters_(op_parameters) {
  model_ = reinterpret_cast<LiteModel *>(model);

  major_dt_ = DT_CPU;
  minor_dt_ = DT_CPU;
  if (context_->IsDeviceTypeEnabled(DT_NPU)) {
    major_dt_ = DT_NPU;
  } else if (context_->IsDeviceTypeEnabled(DT_GPU)) {
    major_dt_ = DT_GPU;
  }

  if (major_dt_ == DT_GPU) {
    major_thread_ = 1;
    minor_thread_ = context_->thread_num_ - 1;
  } else if (major_dt_ == DT_CPU) {
    major_thread_ = UP_DIV(context_->thread_num_, kDefaultSubGraphSize);
    minor_thread_ = static_cast<size_t>(context_->thread_num_ - major_thread_);
  }

  InitSearchTensor();
  return;
}

void SearchSubGraph::InsertParallelNode(uint32_t index, Subgraph *subgraph) {
  if (subgraph == nullptr) {
    return;
  }
  /* declared at function scope so the replacement outlives the recursion below */
  Subgraph new_graph;
  if (subgraph->search_terminate_) {
    if (!subgraph->nodes_.empty()) {
      sub_graphs_.push_back(std::move(*subgraph));
    }
    subgraph = &new_graph;
  }
  LiteGraph::Node *node = node_list_[index];
  // node has already been searched
  if (node == nullptr) {
    return;
  }

  // current node is a parallel target node
  if (IsOfflineParallelNode(node->primitive_, node->device_type_)) {
    // first node searched
    if (subgraph->nodes_.empty()) {
      subgraph->device_ = static_cast<DeviceType>(node->device_type_);
    } else {
      // the previous device type must match the current node's device type
      if (subgraph->device_ != static_cast<DeviceType>(node->device_type_)) {
        return;
      }
    }
    subgraph->nodes_.insert(subgraph->nodes_.begin(), index);
    node_list_[index] = nullptr;
  } else {
    subgraph->search_terminate_ = true;
  }

  // only deal with parallel target nodes
  std::vector<uint32_t> input = node->input_indices_;

  /* remove const tensors */
  for (int i = static_cast<int>(input.size()) - 1; i >= 0; i--) {
    if (tensors_[input[i]].type_ == CONSTANT) {
      VectorErase(&input, input[i]);
    }
  }

  // the search reached a graph input: terminate
  if (std::any_of(input.begin(), input.end(),
                  [&](uint32_t input_index) { return tensors_[input_index].type_ == INPUT; })) {
    subgraph->search_terminate_ = true;
    return;
  }

  // search for next nodes
  for (uint32_t next : input) {
    auto next_nodes = tensors_[next].out_nodes_;
    for (uint32_t next_node : next_nodes) {
      InsertParallelNode(next_node, subgraph);
    }
  }
}

void SearchSubGraph::CheckSubHeadEnd(Subgraph *sub) {
  /* head/end bookkeeping may be stale after subgraph fusion */
  /* check sub head nodes */
  std::vector<uint32_t> delete_head;
  for (uint32_t head_node : sub->heads_) {
    if (std::find(sub->nodes_.begin(), sub->nodes_.end(), head_node) == sub->nodes_.end()) {
      delete_head.push_back(head_node);
      continue;
    }
    LiteGraph::Node *node = model_->graph_.all_nodes_.at(head_node);
    std::vector<uint32_t> in_tensors = node->input_indices_;
    std::vector<uint32_t> in_nodes;
    for (uint32_t in_t : in_tensors) {
      in_nodes.insert(in_nodes.begin(), tensors_.at(in_t).out_nodes_.begin(), tensors_.at(in_t).out_nodes_.end());
    }

    if (in_nodes.empty()) {
      continue;
    }

    bool erase_head = true;
    for (uint32_t in_n : in_nodes) {
      if (std::find(sub->nodes_.begin(), sub->nodes_.end(), in_n) == sub->nodes_.end()) {
        erase_head = false;
        break;
      }
    }
    if (erase_head) {
      delete_head.push_back(head_node);
    }
  }
  for (uint32_t head : delete_head) {
    VectorErase(&sub->heads_, head);
  }

  /* check sub end nodes */
  std::vector<uint32_t> delete_end;
  for (uint32_t end_node : sub->ends_) {
    if (std::find(sub->nodes_.begin(), sub->nodes_.end(), end_node) == sub->nodes_.end()) {
      delete_end.push_back(end_node);
    }
  }
  for (uint32_t end : delete_end) {
    VectorErase(&sub->ends_, end);
  }
  return;
}

bool SearchSubGraph::ValidInParallel() {
  LiteGraph::Node *front_node = model_->graph_.all_nodes_.at(0);
  if (front_node->quant_type_ != schema::QuantType_QUANT_NONE) {
    return false;
  }
  if (major_thread_ < 1 || minor_thread_ < 1) {
    return false;
  }
  if (major_dt_ == DT_NPU) {
    return false;
  }
  if (model_->graph_.sub_graphs_.size() > 1) {
    return false;
  }
  if (model_->GetSchemaVersion() != SCHEMA_VERSION::SCHEMA_CUR) {
    return false;
  }
  return true;
}

void SearchSubGraph::SubGraphSplit() {
  if (!ValidInParallel()) {
    return;
  }

  UpdateOfflineParallelFlag();
  if (offline_parallel_enable_) {
    SubGraphSplitByOffLineParallel();
  } else {
    SubGraphSplitByOutput();
    SubGraphSplitByMiddle();
  }
  return;
}

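/* Iterative counterpart of InsertNode used by the operator-level split: walk
 * single-producer chains upward from an output, collecting nodes until the
 * chain branches; branch points are handed back through *outputs so the
 * caller can seed further subgraphs from them. */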
void SearchSubGraph::InsertNodeBegin(uint32_t index, Subgraph *subgraph, std::vector<size_t> *outputs) {
  size_t last_index = index;

  while (true) {
    LiteGraph::Node *node = node_list_.at(index);
    if (node == nullptr) {
      subgraph->heads_.push_back(last_index);
      return;
    }

    std::vector<uint32_t> input = node->input_indices_;
    RemoveConstNode(&input);

    /* all node inputs are graph inputs */
    for (size_t i = 0; i < input.size(); i++) {
      if (tensors_[input[i]].type_ != INPUT) {
        break;
      }
      subgraph->heads_.push_back(last_index);
      return;
    }

    /* split in graph */
    if (IsNodeSubGraphHead(index, subgraph->nodes_)) {
      if (subgraph->nodes_.empty()) {
        subgraph->heads_.push_back(index);
        subgraph->nodes_.insert(subgraph->nodes_.begin(), index);
        node_list_.at(index) = nullptr;
        for (uint32_t in : input) {
          auto next_nodes = tensors_[in].out_nodes_;
          std::copy(next_nodes.begin(), next_nodes.end(), std::back_inserter(*outputs));
        }
        return;
      }
      subgraph->heads_.push_back(last_index);
      outputs->push_back(index);
      return;
    }

    for (uint32_t in : input) {
      auto next_nodes = tensors_[in].out_nodes_;
      std::copy(next_nodes.begin(), next_nodes.end(), std::back_inserter(*outputs));
    }
    subgraph->nodes_.insert(subgraph->nodes_.begin(), index);
    node_list_.at(index) = nullptr;
    if (outputs->size() == 1) {
      last_index = index;
      index = static_cast<uint32_t>(outputs->at(0));
      outputs->clear();
    } else {
      subgraph->heads_.push_back(index);
      return;
    }
  }
}

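/* Split strategy 2 (per operator): breadth-first from the model outputs,
 * carving the graph into chain-shaped CPU subgraphs whose thread count is
 * capped at kOperatorMaxThreadNum. */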
void SearchSubGraph::SubGraphSplitByOperator() {
  if (!ValidInParallel()) {
    return;
  }
  sub_graphs_.clear();
  node_list_ = model_->graph_.all_nodes_;
  std::queue<size_t> outputs{};
  for (auto out : *output_nodes_) {
    outputs.push(out);
  }
  std::vector<size_t> outputs_vec{};
  while (!outputs.empty()) {
    auto out = outputs.front();
    outputs.pop();

    Subgraph subgraph;
    subgraph.ends_.push_back(out);
    subgraph.device_ = DT_CPU;
    subgraph.thread_ = context_->thread_num_ > kOperatorMaxThreadNum ? kOperatorMaxThreadNum : context_->thread_num_;

    InsertNodeBegin(static_cast<uint32_t>(out), &subgraph, &outputs_vec);
    for (auto new_out : outputs_vec) {
      outputs.push(new_out);
    }
    outputs_vec.clear();
    if (!subgraph.nodes_.empty()) {
      sub_graphs_.push_back(std::move(subgraph));
    }
  }
  ConvertSubGraphToModel(&sub_graphs_);
}
}  // namespace mindspore::lite