1 /**
2 * Copyright 2020 Huawei Technologies Co., Ltd
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 #include "src/sub_graph_split.h"
18 #include <stdlib.h>
19 #include <utility>
20 #include <algorithm>
21 #include <iterator>
22 #include <vector>
23 #include "src/tensor.h"
24 #include "schema/ops_generated.h"
25 #include "schema/model_generated.h"
26 #include "src/ops/populate/populate_register.h"
27 #include "src/scheduler.h"
28 #include "nnacl/pooling_parameter.h"
29 #include "include/model.h"
30 #include "nnacl/base/conv_common_base.h"
31
32 namespace mindspore::lite {
CommConvMul(std::vector<int> weight_shape,std::vector<int> output_shape)33 size_t CommConvMul(std::vector<int> weight_shape, std::vector<int> output_shape) {
34 size_t cost = output_shape[NHWC_N] * output_shape[NHWC_H] * output_shape[NHWC_W] * output_shape[NHWC_C] *
35 weight_shape[NHWC_H] * weight_shape[NHWC_W] * weight_shape[NHWC_C];
36 return cost;
37 }
38
/* Cost of a winograd convolution: placeholder, currently priced as free. */
size_t WinogradConvMul() { return 0; }
43
CommConvdwMul(std::vector<int> weight_shape,std::vector<int> output_shape)44 size_t CommConvdwMul(std::vector<int> weight_shape, std::vector<int> output_shape) {
45 size_t cost = static_cast<size_t>(output_shape[NHWC_N] * output_shape[NHWC_H] * output_shape[NHWC_W] *
46 output_shape[NHWC_C] * weight_shape[NHWC_H] * weight_shape[NHWC_W]);
47 return cost;
48 }
49
/* Cost of a winograd depthwise convolution: placeholder, currently free. */
size_t WinogradConvDwMul() { return 0; }
54
IsOfflineParallelNode(const void * node_primitive,int node_device_type)55 bool IsOfflineParallelNode(const void *node_primitive, int node_device_type) {
56 if (node_primitive == nullptr) {
57 return false;
58 }
59 return (GetPrimitiveType(node_primitive, SCHEMA_VERSION::SCHEMA_CUR) == schema::PrimitiveType_Conv2DFusion) &&
60 (node_device_type != kDefaultDeviceType);
61 }
62
UpdateOfflineParallelFlag()63 void SearchSubGraph::UpdateOfflineParallelFlag() {
64 if (model_ == nullptr) {
65 offline_parallel_enable_ = false;
66 return;
67 }
68 // visited whole models to find any conv && depthwise conv have been set to device type
69 offline_parallel_enable_ = std::any_of(
70 this->model_->graph_.all_nodes_.begin(), this->model_->graph_.all_nodes_.end(),
71 [&](lite::LiteGraph::Node *node) { return IsOfflineParallelNode(node->primitive_, node->device_type_); });
72 }
73
/* Check that `subgraphs` matches the exact shape produced by the offline
 * split pass: exactly kDefaultSubGraphSize branches, each with a single head
 * and a single end that are device-annotated Conv2DFusion nodes, every head
 * fed by a SplitWithOverlap node and every end consumed by a Concat node. */
bool SearchSubGraph::CheckIsParallelSubGraph(const std::vector<Subgraph> &subgraphs) {
  if (subgraphs.size() != kDefaultSubGraphSize) {
    return false;
  }

  for (const auto &sub_graph : subgraphs) {
    auto heads = sub_graph.heads_;
    auto ends = sub_graph.ends_;
    if (heads.size() != kDefaultInputs || ends.size() != kDefaultInputs) {
      return false;
    }
    auto head_node = model_->graph_.all_nodes_.at(heads.front());
    auto end_node = model_->graph_.all_nodes_.at(ends.front());
    if (!IsOfflineParallelNode(head_node->primitive_, head_node->device_type_) ||
        !IsOfflineParallelNode(end_node->primitive_, end_node->device_type_)) {
      return false;
    }

    // 1. check head_node's input is SplitOverlap node
    for (const auto &input : head_node->input_indices_) {
      if (tensors_.at(input).type_ == CONST) {
        continue;
      }
      // NOTE(review): assumes every non-const input tensor has a producer
      // (out_nodes_ non-empty) — confirm upstream guarantees this.
      auto input_node_index = tensors_.at(input).out_nodes_.front();
      if (GetPrimitiveType(model_->graph_.all_nodes_.at(input_node_index)->primitive_, SCHEMA_VERSION::SCHEMA_CUR) !=
          schema::PrimitiveType_SplitWithOverlap) {
        return false;
      }
    }

    // 2. check end_node's output is concat node
    for (const auto &output : end_node->output_indices_) {
      if (tensors_.at(output).type_ == CONST) {
        continue;
      }
      auto output_node_index = tensors_.at(output).in_nodes_.front();
      if (GetPrimitiveType(model_->graph_.all_nodes_.at(output_node_index)->primitive_, SCHEMA_VERSION::SCHEMA_CUR) !=
          schema::PrimitiveType_Concat) {
        return false;
      }
    }
  }
  return true;
}
118
dfs(int i,int n,int current_sum,int except_value,int * min_value,std::vector<bool> * tmp_group,std::vector<bool> * cor_group,std::vector<Subgraph> * sub_graphs)119 void SearchSubGraph::dfs(int i, int n, int current_sum, int except_value, int *min_value, std::vector<bool> *tmp_group,
120 std::vector<bool> *cor_group, std::vector<Subgraph> *sub_graphs) {
121 if (i == n) {
122 if (abs(except_value - current_sum) < *min_value) {
123 for (int j = 0; j < n; j++) {
124 cor_group->at(j) = tmp_group->at(j);
125 }
126 }
127 *min_value = MSMIN(*min_value, abs(except_value - current_sum));
128 return;
129 }
130
131 {
132 tmp_group->at(i) = true;
133 int next_sum = current_sum + sub_graphs->at(i).cost_.cost();
134 dfs(i + 1, n, next_sum, except_value, min_value, tmp_group, cor_group, sub_graphs);
135 }
136
137 {
138 tmp_group->at(i) = false;
139 dfs(i + 1, n, current_sum, except_value, min_value, tmp_group, cor_group, sub_graphs);
140 }
141 return;
142 }
143
/* Estimate the multiply cost of a Conv2DFusion node.
 * The OpParameter is looked up with the node's first output tensor index
 * (same keying used in ConvertSubGraphToModel). Pricing:
 *  - group == 1: 1x1 / winograd / common conv — all currently use CommConvMul
 *    (no dedicated winograd formula yet, see WinogradConvMul);
 *  - depthwise (group == in_channel == out_channel): CommConvdwMul;
 *  - other grouped conv: left at the default-constructed cost. */
SearchSubGraph::CostModel SearchSubGraph::CalculateConv2DFusion(LiteGraph::Node *node) {
  CostModel cost;
  std::vector<uint32_t> inputs = node->input_indices_;
  std::vector<uint32_t> outputs = node->output_indices_;

  // weight is input[1]; NOTE(review): assumes conv nodes always carry at
  // least two inputs and one output — confirm caller filters on Conv2DFusion.
  std::vector<int> weight_shape = src_tensors_->at(inputs[1])->shape();
  std::vector<int> output_shape = src_tensors_->at(outputs[0])->shape();

  ConvParameter *param = reinterpret_cast<ConvParameter *>(op_parameters_->at(outputs[0]));

  if (param->group_ == 1) {
    if (param->kernel_h_ == 1 && param->kernel_w_ == 1) {
      size_t conv1x1_mul_cost = CommConvMul(weight_shape, output_shape);
      cost.mul_cost_ += conv1x1_mul_cost;
    } else {
      int out_unit;
      if (CheckIfUseWinograd(&out_unit, param)) {
        size_t winograd_conv_cost = CommConvMul(weight_shape, output_shape);
        cost.mul_cost_ += winograd_conv_cost;
      } else {
        size_t comm_conv_mul_cost = CommConvMul(weight_shape, output_shape);
        cost.mul_cost_ += comm_conv_mul_cost;
      }
    }
  } else if (param->group_ == param->input_channel_ && param->group_ == param->output_channel_) {
#if defined(ENABLE_ARM) || (defined(ENABLE_SSE) && !defined(ENABLE_AVX))
    if (CheckConvDw1DWinograd(param, context_->thread_num_)) {
      /* ConvolutionDepthwise3x3CPUKernel */
      size_t winograd_convdw_cost = CommConvdwMul(weight_shape, output_shape);
      cost.mul_cost_ += winograd_convdw_cost;
    } else {
      /* ConvolutionDepthwiseIndirectCPUKernel */
      /* ConvolutionDepthwiseSWCPUKernel */
      /* ConvolutionDepthwiseCPUKernel */
      size_t comm_convdw_cost = CommConvdwMul(weight_shape, output_shape);
      cost.mul_cost_ += comm_convdw_cost;
    }
#else
    size_t comm_convdw_cost = CommConvdwMul(weight_shape, output_shape);
    cost.mul_cost_ += comm_convdw_cost;
#endif
  } else {
    /* grouped conv: no cost model yet */
  }
  return cost;
}
190
CreatePartialPrimitive(int64_t subgraph_index)191 const schema::Primitive *SearchSubGraph::CreatePartialPrimitive(int64_t subgraph_index) {
192 flatbuffers::FlatBufferBuilder fbb(1024);
193 auto val_offset = schema::CreatePartialFusion(fbb, subgraph_index);
194 auto prim_offset = schema::CreatePrimitive(fbb, schema::PrimitiveType_PartialFusion, val_offset.o);
195 fbb.Finish(prim_offset);
196 auto tmp_buf = fbb.GetBufferPointer();
197 void *prim_buf = malloc(fbb.GetSize());
198 if (prim_buf == nullptr) {
199 return nullptr;
200 }
201 memcpy(prim_buf, tmp_buf, fbb.GetSize());
202
203 auto primitive = flatbuffers::GetRoot<schema::Primitive>(prim_buf);
204 fbb.Clear();
205
206 model_->node_bufs_.push_back(prim_buf);
207 return std::move(primitive);
208 }
209
/* Materialize each non-empty entry of *sub_graphs as a real model sub-graph:
 * its nodes are moved from the main graph into a new LiteGraph::SubGraph and a
 * single PartialFusion node calling it is inserted into the main graph at the
 * position of the earliest moved node. Only runs when exactly
 * kDefaultSubGraphSize subgraphs are given; *sub_graphs is cleared at the end. */
void SearchSubGraph::ConvertSubGraphToModel(std::vector<Subgraph> *sub_graphs) {
  if (sub_graphs->size() != kDefaultSubGraphSize) {
    return;
  }
  LiteGraph::SubGraph *main_graphs = model_->graph_.sub_graphs_.front();

  for (Subgraph &subgraph : *sub_graphs) {
    if (subgraph.nodes_.empty()) {
      continue;
    }

    DeviceType device_type = subgraph.device_;
    size_t thread_num = subgraph.thread_;
    int new_sub_index = static_cast<int>(model_->graph_.sub_graphs_.size());
    int partial_index = static_cast<int>(model_->graph_.all_nodes_.size());
    // insertion position of the partial node in the main graph: the smallest
    // main-graph position among the nodes being moved out
    int particial_replace_index = partial_index;

    LiteGraph::SubGraph *new_sub_graph = new (std::nothrow) LiteGraph::SubGraph();
    if (new_sub_graph == nullptr) {
      MS_LOG(ERROR) << "New sub graph failed!";
      return;
    }
    new_sub_graph->name_ = "SubSplit" + std::to_string(new_sub_index);
    LiteGraph::Node *new_partial_node = new (std::nothrow) LiteGraph::Node();
    if (new_partial_node == nullptr) {
      MS_LOG(ERROR) << "New partial node failed!";
      delete new_sub_graph;
      return;
    }
    // the partial node's name carries the target device for debugging
    new_partial_node->name_ = "SubSplitPartial" + std::to_string(new_sub_index);
    if (device_type == DT_CPU) {
      new_partial_node->name_ = "Cpu" + new_partial_node->name_;
    } else if (device_type == DT_GPU) {
      new_partial_node->name_ = "Gpu" + new_partial_node->name_;
    } else if (device_type == DT_NPU) {
      new_partial_node->name_ = "Npu" + new_partial_node->name_;
    }

    new_partial_node->node_type_ = static_cast<int>(mindspore::lite::MSNodeType::NodeType_ValueNode);
    // NOTE(review): CreatePartialPrimitive returns nullptr on OOM and the
    // result is not checked here — confirm downstream tolerates a null primitive.
    new_partial_node->primitive_ = CreatePartialPrimitive(new_sub_index);

    // move every node of the subgraph out of the main graph
    while (!subgraph.nodes_.empty()) {
      uint32_t node_index = subgraph.nodes_.front();
      LiteGraph::Node *cur_node = model_->graph_.all_nodes_[node_index];
      new_sub_graph->node_indices_.push_back(node_index);

      auto iter = find(main_graphs->node_indices_.begin(), main_graphs->node_indices_.end(), node_index);
      int cur_node_index = std::distance(std::begin(main_graphs->node_indices_), iter);
      particial_replace_index = (cur_node_index < particial_replace_index) ? cur_node_index : particial_replace_index;

      VectorErase(&main_graphs->node_indices_, node_index);
      VectorErase(&subgraph.nodes_, node_index);
      cur_node->device_type_ = static_cast<int>(device_type);
      // OpParameter map is keyed by the node's first output tensor index
      op_parameters_->at(static_cast<int>(cur_node->output_indices_.at(0)))->thread_num_ = thread_num;
    }

    // sub-graph inputs: non-const head inputs produced outside the new
    // sub-graph (skipping duplicates)
    for (uint32_t head_index : subgraph.heads_) {
      LiteGraph::Node *head_node = model_->graph_.all_nodes_[head_index];
      std::vector<uint32_t> inputs = head_node->input_indices_;
      for (auto input : inputs) {
        if (tensors_[input].type_ == CONST) {
          continue;
        }
        if (std::find(new_sub_graph->input_indices_.begin(), new_sub_graph->input_indices_.end(), input) !=
            new_sub_graph->input_indices_.end()) {
          continue;
        }

        auto in_tensor_in_nodes = tensors_[input].out_nodes_;
        if (!in_tensor_in_nodes.empty()) {
          uint32_t in_tensor_in_node = in_tensor_in_nodes[0];
          // producer lives inside the new sub-graph -> internal tensor, skip
          if (std::find(new_sub_graph->node_indices_.begin(), new_sub_graph->node_indices_.end(), in_tensor_in_node) !=
              new_sub_graph->node_indices_.end()) {
            continue;
          }
        }

        new_sub_graph->input_indices_.insert(new_sub_graph->input_indices_.end(), input);
        new_partial_node->input_indices_.insert(new_partial_node->input_indices_.end(), input);
      }
    }

    // sub-graph outputs: everything the end nodes produce
    for (uint32_t end_index : subgraph.ends_) {
      LiteGraph::Node *end_node = model_->graph_.all_nodes_[end_index];
      std::vector<uint32_t> outputs = end_node->output_indices_;
      new_sub_graph->output_indices_.insert(new_sub_graph->output_indices_.end(), outputs.begin(), outputs.end());
      new_partial_node->output_indices_.insert(new_partial_node->output_indices_.end(), outputs.begin(), outputs.end());
    }

    main_graphs->node_indices_.insert(main_graphs->node_indices_.begin() + particial_replace_index, partial_index);
    model_->graph_.all_nodes_.push_back(std::move(new_partial_node));
    model_->graph_.sub_graphs_.push_back(std::move(new_sub_graph));
  }

  sub_graphs->clear();
  return;
}
307
IsNodeSubGraphHead(uint32_t node_index,const std::vector<uint32_t> & ready_nodes)308 bool SearchSubGraph::IsNodeSubGraphHead(uint32_t node_index, const std::vector<uint32_t> &ready_nodes) {
309 std::vector<uint32_t> output_indexes = model_->graph_.all_nodes_.at(node_index)->output_indices_;
310 std::vector<uint32_t> output_nodes;
311 for (uint32_t out_t : output_indexes) {
312 std::vector<uint32_t> cur_nodes = tensors_[out_t].in_nodes_;
313 output_nodes.insert(output_nodes.end(), cur_nodes.begin(), cur_nodes.end());
314 }
315 if (output_indexes.size() == 1 && output_nodes.size() == 1) {
316 return false;
317 }
318 for (uint32_t out_n : output_nodes) {
319 if (find(ready_nodes.begin(), ready_nodes.end(), out_n) == ready_nodes.end()) {
320 return true;
321 }
322 }
323 return false;
324 }
325
IsNodeSubGraphHeadWithRoot(uint32_t node_index,const std::vector<uint32_t> & ready_nodes,uint32_t root_node_index)326 bool SearchSubGraph::IsNodeSubGraphHeadWithRoot(uint32_t node_index, const std::vector<uint32_t> &ready_nodes,
327 uint32_t root_node_index) {
328 std::vector<uint32_t> output_indexes = node_list_.at(node_index)->output_indices_;
329 std::vector<uint32_t> output_nodes;
330 for (uint32_t out_t : output_indexes) {
331 std::vector<uint32_t> cur_nodes = tensors_[out_t].in_nodes_;
332 output_nodes.insert(output_nodes.end(), cur_nodes.begin(), cur_nodes.end());
333 }
334 for (uint32_t out_n : output_nodes) {
335 if (root_node_index != out_n) {
336 if (find(ready_nodes.begin(), ready_nodes.end(), out_n) == ready_nodes.end()) {
337 return true;
338 }
339 }
340 }
341 return false;
342 }
343
SearchMultyInNodes(std::vector<uint32_t> * multy_in_nodes)344 void SearchSubGraph::SearchMultyInNodes(std::vector<uint32_t> *multy_in_nodes) {
345 std::vector<uint32_t> all_main_sub_nodes = model_->graph_.sub_graphs_[0]->node_indices_;
346 for (size_t i = 0; i < all_main_sub_nodes.size(); i++) {
347 uint32_t node_index = all_main_sub_nodes[i];
348 LiteGraph::Node *node = node_list_[node_index];
349
350 if (IsPartialNode(node->primitive_, model_->GetSchemaVersion())) {
351 continue;
352 }
353 int input_count = std::count_if(node->input_indices_.begin(), node->input_indices_.end(),
354 [&](uint32_t in_tensor_index) { return tensors_[in_tensor_index].type_ != CONST; });
355 if (input_count > 1) {
356 multy_in_nodes->push_back(node_index);
357 }
358 }
359 return;
360 }
361
RemoveConstNode(std::vector<uint32_t> * nodes)362 void SearchSubGraph::RemoveConstNode(std::vector<uint32_t> *nodes) {
363 bool stop_search = false;
364 while (!stop_search) {
365 stop_search = true;
366 for (size_t i = 0; i < nodes->size(); i++) {
367 if (tensors_[nodes->at(i)].type_ == CONST) {
368 VectorErase(nodes, nodes->at(i));
369 stop_search = false;
370 break;
371 }
372 }
373 }
374
375 return;
376 }
377
/* Depth-first, output-to-input growth of *subgraph starting at `index`.
 * Inserts the node, then recurses into the producers of its non-const inputs.
 * `last_index` is the consumer on whose behalf we were called; it becomes a
 * head when this node cannot be absorbed. The search terminates (clearing the
 * subgraph) when it reaches the graph inputs. */
void SearchSubGraph::InsertNode(uint32_t index, Subgraph *subgraph, uint32_t last_index) {
  if (subgraph->search_terminate_) {
    return;
  }

  LiteGraph::Node *node = node_list_.at(index);
  if (node == nullptr) {
    // already claimed by an earlier traversal
    return;
  }

  std::vector<uint32_t> input = node->input_indices_;
  RemoveConstNode(&input);

  /* all node_input is graph_input */
  // NOTE(review): only input[0] is effectively tested — the loop returns on
  // the first INPUT tensor and breaks on the first non-INPUT one; confirm
  // this matches the "all inputs" intent stated above.
  for (size_t i = 0; i < input.size(); i++) {
    if (tensors_[input[i]].type_ != INPUT) {
      break;
    }
    subgraph->heads_.clear();
    subgraph->ends_.clear();
    subgraph->nodes_.clear();
    subgraph->search_terminate_ = true;
    return;
  }

  /* split in graph */
  if (IsNodeSubGraphHead(index, subgraph->nodes_)) {
    if (subgraph->nodes_.empty()) {
      subgraph->search_terminate_ = true;
      return;
    }
    subgraph->heads_.push_back(last_index);
    return;
  }

  // graph outputs reached by the search become subgraph ends
  if (find(output_nodes_->begin(), output_nodes_->end(), index) != output_nodes_->end()) {
    subgraph->ends_.push_back(index);
  }

  /* node insert in current subgraph */
  subgraph->nodes_.insert(subgraph->nodes_.begin(), index);
  node_list_.at(index) = nullptr;

  /* search for next node */
  for (uint32_t in : input) {
    auto next_nodes = tensors_[in].out_nodes_;
    for (uint32_t next_node : next_nodes) {
      InsertNode(next_node, subgraph, index);
    }
  }
  return;
}
430
/* Post-fusion cleanup: for each fused subgraph, try to absorb the single
 * producer feeding each head node — allowed when that producer has exactly one
 * non-const input and all of its consumers (root node excepted) already lie
 * inside the subgraph. Successfully extended heads are removed afterwards,
 * then head/end sets are re-validated and node indices sorted. */
void SearchSubGraph::OptimizeAfterFusion(std::vector<Subgraph> *sub_graphs, uint32_t root_node_index) {
  MS_ASSERT(sub_graphs->size() == kDefaultSubGraphSize);
  for (Subgraph &sub : *sub_graphs) {
    if (sub.nodes_.empty()) {
      return;
    }
    int head_size = static_cast<int>(sub.heads_.size());
    std::vector<uint32_t> used_heads;
    for (int i = 0; i < head_size; i++) {
      uint32_t head_node_index = sub.heads_.at(i);
      // heads_ can grow during the loop; stop once we revisit a consumed head
      if (std::find(used_heads.begin(), used_heads.end(), head_node_index) != used_heads.end()) {
        break;
      }
      std::vector<uint32_t> head_input_tensors = model_->graph_.all_nodes_[head_node_index]->input_indices_;
      RemoveConstNode(&head_input_tensors);
      if (head_input_tensors.size() != 1) continue;

      std::vector<uint32_t> input_nodes = tensors_.at(head_input_tensors.at(0)).out_nodes_;
      if (input_nodes.size() != 1) continue;
      uint32_t input_node_index = input_nodes.at(0);

      std::vector<uint32_t> input_tensors = model_->graph_.all_nodes_[input_node_index]->input_indices_;
      RemoveConstNode(&input_tensors);
      if (input_tensors.size() != 1) continue;

      /* this node qualified:
       * 1. the only input node of current head node
       * 2. all output included in current subgraph
       * 3. one input-tensor */
      if (!IsNodeSubGraphHeadWithRoot(input_node_index, sub.nodes_, root_node_index)) {
        InsertHeadNode(input_node_index, &sub);
        used_heads.push_back(head_node_index); /* delete used head at end */
      }
      head_size = static_cast<int>(sub.heads_.size());
    }
    for (auto head_index : used_heads) {
      VectorErase(&sub.heads_, head_index);
    }

    CheckSubHeadEnd(&sub);

    /* sort node index */
    std::sort(sub.nodes_.begin(), sub.nodes_.end());
  }
}
476
InsertHeadNode(uint32_t head_node_index,Subgraph * subgraph)477 void SearchSubGraph::InsertHeadNode(uint32_t head_node_index, Subgraph *subgraph) {
478 LiteGraph::Node *node = node_list_.at(head_node_index);
479 std::vector<uint32_t> head_node_inputs = node->input_indices_;
480 RemoveConstNode(&head_node_inputs);
481
482 subgraph->nodes_.push_back(head_node_index);
483 node_list_.at(head_node_index) = nullptr;
484
485 /* search for next node */
486 size_t current_node_size = subgraph->nodes_.size();
487 for (uint32_t in : head_node_inputs) {
488 auto next_nodes = tensors_[in].out_nodes_;
489 for (uint32_t next_node : next_nodes) {
490 InsertNodeByMid(next_node, subgraph, head_node_index);
491 }
492 }
493
494 if (current_node_size == subgraph->nodes_.size()) {
495 subgraph->heads_.push_back(head_node_index);
496 }
497 return;
498 }
499
/* Grow *subgraph upward from node_index, a producer of some input of the node
 * identified by last_index (already inside the subgraph).
 * Multi-input nodes searched earlier (keys of node_sub_map_) are merged in as
 * pre-computed units: their nodes and heads are absorbed, heads whose single
 * producer still fits are expanded recursively, and the map entry is consumed.
 * Plain nodes are prepended and the search continues into their non-const
 * input producers. */
void SearchSubGraph::InsertNodeByMid(uint32_t node_index, Subgraph *subgraph, uint32_t last_index) {
  LiteGraph::Node *node = node_list_.at(node_index);
  MS_CHECK_PTR_IF_NULL(node);

  auto subs_iter = node_sub_map_.find(node_index);
  if (subs_iter != node_sub_map_.end()) {
    /* node is multy-in node , already searched before */

    if (IsNodeSubGraphHead(node_index, subgraph->nodes_)) {
      /* this node can not be included in this subgraph */
      if (!subgraph->nodes_.empty()) subgraph->heads_.push_back(last_index);
      return;
    }

    /* include this multy-in-unit in current subgraph */
    std::vector<Subgraph> &subs = subs_iter->second;

    /* insert nodes */
    subgraph->nodes_.push_back(node_index);
    for (Subgraph &sub : subs) {
      subgraph->nodes_.insert(subgraph->nodes_.end(), sub.nodes_.begin(), sub.nodes_.end());
    }

    /* insert heads */
    std::set<uint32_t> subs_head;
    subs_head.insert(node_index);
    for (Subgraph &sub : subs) {
      for (uint32_t head : sub.heads_) {
        subs_head.insert(head);
      }
    }

    // try to extend past each head: a head whose single producer has all its
    // consumers inside the subgraph is absorbed and stops being a head
    std::set<uint32_t> subs_head_baklist = subs_head;
    for (uint32_t head_node : subs_head) {
      std::vector<uint32_t> head_input_tensors = model_->graph_.all_nodes_[head_node]->input_indices_;
      RemoveConstNode(&head_input_tensors);
      if (head_input_tensors.size() != 1) continue;
      std::vector<uint32_t> input_nodes = tensors_.at(head_input_tensors.at(0)).out_nodes_;
      if (input_nodes.size() != 1) continue;

      uint32_t input_node = input_nodes.at(0);
      if (!IsNodeSubGraphHead(input_node, subgraph->nodes_)) {
        InsertNodeByMid(input_node, subgraph, head_node);
        subs_head_baklist.erase(head_node);
      }
    }

    /* stop search */
    for (auto head : subs_head_baklist) {
      subgraph->heads_.push_back(head);
    }
    // the unit has been merged; its map entry must not be merged twice
    node_sub_map_.erase(node_index);
    return;
  }

  std::vector<uint32_t> inputs = node->input_indices_;
  RemoveConstNode(&inputs);

  if (IsNodeSubGraphHead(node_index, subgraph->nodes_)) {
    // node has consumers outside the subgraph: last_index becomes a head
    if (!subgraph->nodes_.empty()) {
      if (std::find(subgraph->heads_.begin(), subgraph->heads_.end(), last_index) == subgraph->heads_.end()) {
        subgraph->heads_.push_back(last_index);
      }
    }
    return;
  }

  subgraph->nodes_.insert(subgraph->nodes_.begin(), node_index);
  node_list_.at(node_index) = nullptr;  // claimed: never visit again

  /* search for next node */
  for (uint32_t in : inputs) {
    auto next_nodes = tensors_[in].out_nodes_;
    if (next_nodes.size() == 0) {
      // no producer for this input: the current front of the subgraph is a head
      if (!subgraph->nodes_.empty()) subgraph->heads_.push_back(subgraph->nodes_.front());
    } else {
      for (uint32_t next_node : next_nodes) {
        InsertNodeByMid(next_node, subgraph, node_index);
      }
    }
  }
  return;
}
583
InitMiddleSubgraph(std::vector<uint32_t> * multy_in_nodes)584 void SearchSubGraph::InitMiddleSubgraph(std::vector<uint32_t> *multy_in_nodes) {
585 for (uint32_t node_index : *multy_in_nodes) {
586 std::vector<Subgraph> node_subs;
587 LiteGraph::Node *node = node_list_[node_index];
588 for (uint32_t input_tensor_index : node->input_indices_) {
589 Tensor *tensor = &tensors_[input_tensor_index];
590 if (tensor->type_ == CONST || tensor->type_ == INPUT) continue;
591
592 std::vector<uint32_t> input_nodes = tensor->out_nodes_;
593 if (input_nodes.empty()) continue;
594 if (input_nodes.size() != 1) continue;
595 uint32_t input_node = input_nodes[0];
596
597 Subgraph sub;
598 sub.ends_.push_back(input_node);
599 InsertNodeByMid(input_node, &sub, input_node);
600 node_subs.push_back(sub);
601 }
602 if (!node_subs.empty()) {
603 node_sub_map_.insert(std::make_pair(node_index, node_subs));
604 }
605 }
606 return;
607 }
608
InitSearchSubGraphByMiddle()609 void SearchSubGraph::InitSearchSubGraphByMiddle() {
610 sub_graphs_.clear();
611 node_list_ = model_->graph_.all_nodes_;
612
613 std::vector<uint32_t> multy_in_nodes;
614
615 SearchMultyInNodes(&multy_in_nodes);
616
617 if (multy_in_nodes.size() > kMaxMultyInNode) {
618 node_sub_map_.clear();
619 return;
620 }
621
622 InitMiddleSubgraph(&multy_in_nodes);
623
624 if (node_sub_map_.size() > kMaxSubGraphCount) {
625 node_sub_map_.clear();
626 }
627 return;
628 }
629
InitSearchSubGraphByOutput()630 void SearchSubGraph::InitSearchSubGraphByOutput() {
631 sub_graphs_.clear();
632 node_list_ = model_->graph_.all_nodes_;
633
634 for (auto out : *output_nodes_) {
635 Subgraph subgraph;
636
637 InsertNode(static_cast<uint32_t>(out), &subgraph, static_cast<uint32_t>(out));
638
639 sub_graphs_.push_back(std::move(subgraph));
640 }
641 return;
642 }
643
InitSearchTensor()644 void SearchSubGraph::InitSearchTensor() {
645 tensors_.resize(model_->graph_.all_tensors_.size());
646
647 /* Set Tensor Type */
648 for (size_t i = 0; i < tensors_.size(); i++) {
649 tensors_[i].type_ = NORMAL;
650 mindspore::schema::Tensor *src_tensor = static_cast<schema::Tensor *>(model_->graph_.all_tensors_[i]);
651 auto category = TensorCategory(src_tensor);
652 if (category == mindspore::lite::Tensor::Category::CONST_TENSOR ||
653 category == mindspore::lite::Tensor::Category::CONST_SCALAR) {
654 tensors_[i].type_ = CONST;
655 }
656 }
657 std::vector<uint32_t> graph_input = model_->graph_.sub_graphs_[0]->input_indices_;
658 for (auto in : graph_input) {
659 tensors_[in].type_ = INPUT;
660 }
661
662 /* Set Tensor In and out Node */
663 for (size_t index = 0; index < model_->graph_.all_nodes_.size(); index++) {
664 LiteGraph::Node *node = model_->graph_.all_nodes_[index];
665 std::vector<uint32_t> input = node->input_indices_;
666 for (uint32_t in : input) {
667 tensors_[in].in_nodes_.push_back(index);
668 }
669 std::vector<uint32_t> output = node->output_indices_;
670 for (uint32_t out : output) {
671 tensors_[out].out_nodes_.push_back(index);
672 }
673 }
674 return;
675 }
676
/* Assign each subgraph to the major or minor device. dfs() enumerates every
 * subset of *sub_graphs and records in cor_group the subset whose summed cost
 * is closest to except_value (= total_cost_ * kDefaultGpu). The selection is
 * inverted if needed so the major device always receives the bigger half,
 * then device / thread count / tid are stamped onto each subgraph. */
void SearchSubGraph::InitSubgraphRuntimeInfo(std::vector<Subgraph> *sub_graphs) {
  std::vector<bool> tmp_group;
  std::vector<bool> cor_group;

  tmp_group.resize(sub_graphs->size());
  cor_group.resize(sub_graphs->size());

  int except_value = static_cast<int>(total_cost_ * kDefaultGpu); /* major device responsible for 50% calculation */
  int min_value = INT32_MAX;

  dfs(0, static_cast<int>(sub_graphs->size()), 0, except_value, &min_value, &tmp_group, &cor_group, sub_graphs);

  /* make bigger half using major_dt_ */
  int true_value = 0;
  for (size_t i = 0; i < sub_graphs->size(); i++) {
    if (cor_group.at(i)) {
      true_value += sub_graphs->at(i).cost_.cost();
    }
  }

  if (true_value < except_value) {
    // the selected half is the smaller one: flip so major gets more work
    (void)std::transform(cor_group.begin(), cor_group.end(), cor_group.begin(), [](bool value) { return !value; });
  }

  for (size_t i = 0; i < sub_graphs->size(); i++) {
    if (cor_group.at(i)) {
      sub_graphs->at(i).device_ = major_dt_;
      sub_graphs->at(i).thread_ = major_thread_;
      sub_graphs->at(i).tid_ = 0;
    } else {
      sub_graphs->at(i).device_ = minor_dt_;
      sub_graphs->at(i).thread_ = minor_thread_;
      sub_graphs->at(i).tid_ = 1;
    }
  }
}
713
InitMainGraphDevice(DeviceType dt)714 void SearchSubGraph::InitMainGraphDevice(DeviceType dt) {
715 LiteGraph::SubGraph *main_graph = model_->graph_.sub_graphs_.front();
716 for (uint32_t node_index : main_graph->node_indices_) {
717 LiteGraph::Node *node = model_->graph_.all_nodes_[node_index];
718 node->device_type_ = dt;
719 }
720 }
721
/* Repeatedly merge pairs of subgraphs that share the same tid_ until only
 * kDefaultSubGraphSize remain. The merged subgraph inherits sub1's
 * device / thread / tid and concatenates the nodes, heads and ends of both.
 * NOTE(review): assumes a same-tid pair always exists while the size exceeds
 * the target; otherwise sub1_index == size() and at() would throw — confirm
 * InitSubgraphRuntimeInfo guarantees this invariant. */
void SearchSubGraph::SubgraphFusion(std::vector<Subgraph> *sub_graphs) {
  while (sub_graphs->size() > kDefaultSubGraphSize) {
    size_t sub1_index = 0;
    size_t sub2_index = 0;
    bool is_found = false;
    // find the first pair of subgraphs scheduled on the same tid
    for (sub1_index = 0; sub1_index < sub_graphs->size(); sub1_index++) {
      for (size_t tmp2 = sub1_index + 1; tmp2 < sub_graphs->size(); tmp2++) {
        if (sub_graphs->at(sub1_index).tid_ == sub_graphs->at(tmp2).tid_) {
          sub2_index = tmp2;
          is_found = true;
          break;
        }
      }
      if (is_found) {
        break;
      }
    }
    MS_ASSERT(sub2_index > sub1_index); /* erase sub2 then sub1 */

    Subgraph new_sub;
    new_sub.device_ = sub_graphs->at(sub1_index).device_;
    new_sub.thread_ = sub_graphs->at(sub1_index).thread_;
    new_sub.tid_ = sub_graphs->at(sub1_index).tid_;
    new_sub.cost_ = sub_graphs->at(sub1_index).cost_ + sub_graphs->at(sub2_index).cost_;

    Subgraph &sub1 = sub_graphs->at(sub1_index);
    Subgraph &sub2 = sub_graphs->at(sub2_index);
    new_sub.nodes_.insert(new_sub.nodes_.end(), sub1.nodes_.begin(), sub1.nodes_.end());
    new_sub.nodes_.insert(new_sub.nodes_.end(), sub2.nodes_.begin(), sub2.nodes_.end());
    new_sub.heads_.insert(new_sub.heads_.end(), sub1.heads_.begin(), sub1.heads_.end());
    new_sub.heads_.insert(new_sub.heads_.end(), sub2.heads_.begin(), sub2.heads_.end());
    new_sub.ends_.insert(new_sub.ends_.end(), sub1.ends_.begin(), sub1.ends_.end());
    new_sub.ends_.insert(new_sub.ends_.end(), sub2.ends_.begin(), sub2.ends_.end());
    // erase higher index first so the lower index stays valid
    sub_graphs->erase(sub_graphs->begin() + sub2_index);
    sub_graphs->erase(sub_graphs->begin() + sub1_index);
    sub_graphs->insert(sub_graphs->end(), std::move(new_sub));
  }

  return;
}
762
/* Recompute each subgraph's cost and the global total_cost_. Only
 * Conv2DFusion nodes get a real cost estimate; every other node contributes
 * the default mul_cost_ of 1 (io_cost_ 0). */
void SearchSubGraph::CalculateCostModel(std::vector<Subgraph> *sub_graphs) {
  total_cost_ = 0;
  for (Subgraph &subgraph : *sub_graphs) {
    // NOTE(review): empty() presumably resets the accumulated CostModel to
    // zero; the name reads like a query — verify in sub_graph_split.h.
    subgraph.cost_.empty();
    std::vector<uint32_t> nodes = subgraph.nodes_;
    for (uint32_t node_index : nodes) {
      CostModel cost;
      cost.io_cost_ = 0;
      cost.mul_cost_ = 1;  // unit cost for nodes without a dedicated model

      LiteGraph::Node *node = model_->graph_.all_nodes_[node_index];
      if (GetPrimitiveType(node->primitive_, SCHEMA_VERSION::SCHEMA_CUR) == schema::PrimitiveType_Conv2DFusion) {
        cost = CalculateConv2DFusion(node);
      }

      subgraph.cost_ = subgraph.cost_ + cost;
      total_cost_ += static_cast<size_t>(cost.cost());
    }
  }
}
783
SubGraphSplitByOutput()784 void SearchSubGraph::SubGraphSplitByOutput() {
785 if (output_nodes_->size() < kDefaultSubGraphSize) {
786 return;
787 }
788
789 InitSearchSubGraphByOutput();
790 CalculateCostModel(&sub_graphs_);
791 InitSubgraphRuntimeInfo(&sub_graphs_);
792 SubgraphFusion(&sub_graphs_);
793 for (Subgraph &sub : sub_graphs_) {
794 CheckSubHeadEnd(&sub);
795 }
796
797 if (sub_graphs_.at(kDefaultFirstSubgraph).cost_.cost() < kMinSubgraphCost ||
798 sub_graphs_.at(kDefaultSecondSubgraph).cost_.cost() < kMinSubgraphCost) {
799 return;
800 }
801
802 ConvertSubGraphToModel(&sub_graphs_);
803 }
804
/* Split the model at multi-input ("middle") nodes: for each candidate unit in
 * node_sub_map_, price the branches, assign devices, fuse to the target
 * count, extend past head nodes, re-price, and convert to model sub-graphs
 * when both halves are expensive enough. */
void SearchSubGraph::SubGraphSplitByMiddle() {
  InitSearchSubGraphByMiddle();
  // NOTE(review): `auto map` copies each map entry; `subgraphs` refers into
  // that per-iteration copy, so modifications never write back into
  // node_sub_map_ — confirm this is intended before changing to a reference.
  for (auto map : node_sub_map_) {
    std::vector<Subgraph> &subgraphs = map.second;
    if (subgraphs.size() < kDefaultSubGraphSize) {
      continue;
    }

    CalculateCostModel(&subgraphs);
    if (total_cost_ < kMinSubgraphCost) {
      continue;
    }

    InitSubgraphRuntimeInfo(&subgraphs);
    SubgraphFusion(&subgraphs);

    MS_ASSERT(subgraphs.size() == kDefaultSubGraphSize);
    if (subgraphs.at(kDefaultFirstSubgraph).nodes_.empty() || subgraphs.at(kDefaultSecondSubgraph).nodes_.empty()) {
      continue;
    }

    OptimizeAfterFusion(&subgraphs, map.first);

    /* redo cost-model and pre-set-info after optimize */
    CalculateCostModel(&subgraphs);
    if (subgraphs.at(kDefaultFirstSubgraph).cost_.cost() < kMinSubgraphCost ||
        subgraphs.at(kDefaultSecondSubgraph).cost_.cost() < kMinSubgraphCost) {
      continue;
    }

    InitSubgraphRuntimeInfo(&subgraphs);

    InitMainGraphDevice(DT_CPU);

    ConvertSubGraphToModel(&subgraphs);
  }
}
842
/* Split using offline device annotations: only Concat-rooted multi-input
 * units are considered. Each qualifying parallel unit (validated by
 * CheckIsParallelSubGraph) gets its device from its head node's annotation
 * (GPU branches run as the major tid 0, others as the minor tid 1) and is
 * converted to a model sub-graph; the remaining main graph runs on CPU. */
void SearchSubGraph::SubGraphSplitByOffLineParallel() {
  sub_graphs_.clear();
  node_list_ = model_->graph_.all_nodes_;

  std::vector<uint32_t> multy_in_nodes;

  SearchMultyInNodes(&multy_in_nodes);

  for (uint32_t node_index : multy_in_nodes) {
    LiteGraph::Node *node = node_list_[node_index];
    if (GetPrimitiveType(node->primitive_, SCHEMA_VERSION::SCHEMA_CUR) != schema::PrimitiveType_Concat) {
      continue;
    }
    std::vector<Subgraph> node_subs;
    for (uint32_t input_tensor_index : node->input_indices_) {
      Tensor *tensor = &tensors_[input_tensor_index];
      if (tensor->type_ == CONST) continue;
      std::vector<uint32_t> input_nodes = tensor->out_nodes_;
      // NOTE(review): input_nodes[0] is read without an empty() check, unlike
      // InitMiddleSubgraph — confirm non-const concat inputs always have a
      // producer here.
      Subgraph sub;
      sub.ends_.push_back(input_nodes[0]);
      InsertNodeByMid(input_nodes[0], &sub, input_nodes[0]);
      node_subs.push_back(sub);
    }
    node_sub_map_.insert(std::make_pair(node_index, node_subs));
  }

  // NOTE(review): `auto map` copies each entry; `subgraphs` aliases the copy,
  // so changes do not persist in node_sub_map_ — confirm intended.
  for (auto map : node_sub_map_) {
    std::vector<Subgraph> &subgraphs = map.second;

    if (std::any_of(subgraphs.begin(), subgraphs.end(), [&](Subgraph &sub) { return sub.nodes_.empty(); })) {
      continue;
    }

    if (!CheckIsParallelSubGraph(subgraphs)) {
      continue;
    }

    // init graph device type
    for (auto &subgraph : subgraphs) {
      uint32_t head_node_index = subgraph.heads_.front();
      subgraph.device_ = static_cast<lite::DeviceType>(model_->graph_.all_nodes_.at(head_node_index)->device_type_);
      if (subgraph.device_ == DT_GPU) {
        subgraph.thread_ = major_thread_;
        subgraph.tid_ = 0;
      } else {
        subgraph.thread_ = minor_thread_;
        subgraph.tid_ = 1;
      }
    }
    ConvertSubGraphToModel(&subgraphs);
  }
  InitMainGraphDevice(DT_CPU);
}
896
SearchSubGraph(const InnerContext * context,Model * model,std::vector<lite::Tensor * > * src_tensors,const std::map<int,OpParameter * > * op_parameters,std::vector<size_t> * output_nodes)897 SearchSubGraph::SearchSubGraph(const InnerContext *context, Model *model, std::vector<lite::Tensor *> *src_tensors,
898 const std::map<int, OpParameter *> *op_parameters, std::vector<size_t> *output_nodes)
899 : output_nodes_(output_nodes), context_(context), src_tensors_(src_tensors), op_parameters_(op_parameters) {
900 model_ = reinterpret_cast<LiteModel *>(model);
901
902 major_dt_ = DT_CPU;
903 minor_dt_ = DT_CPU;
904 if (context_->IsNpuEnabled()) {
905 major_dt_ = DT_NPU;
906 } else if (context_->IsGpuEnabled()) {
907 major_dt_ = DT_GPU;
908 }
909
910 if (major_dt_ == DT_GPU) {
911 major_thread_ = 1;
912 minor_thread_ = context_->thread_num_ - 1;
913 } else if (major_dt_ == DT_CPU) {
914 major_thread_ = UP_DIV(context_->thread_num_, kDefaultSubGraphSize);
915 minor_thread_ = static_cast<size_t>(context_->thread_num_ - major_thread_);
916 }
917 MS_ASSERT(major_thread_ > 0);
918 MS_ASSERT(minor_thread_ > 0);
919
920 InitSearchTensor();
921 return;
922 }
923
InsertParallelNode(uint32_t index,Subgraph * subgraph)924 void SearchSubGraph::InsertParallelNode(uint32_t index, Subgraph *subgraph) {
925 if (subgraph == nullptr) {
926 return;
927 }
928 if (subgraph->search_terminate_) {
929 if (!subgraph->nodes_.empty()) {
930 sub_graphs_.push_back(std::move(*subgraph));
931 }
932 Subgraph new_graph;
933 subgraph = &new_graph;
934 }
935 LiteGraph::Node *node = node_list_[index];
936 // has been searched
937 if (node == nullptr) {
938 return;
939 }
940
941 // if current node is parallel target node
942 if (IsOfflineParallelNode(node->primitive_, node->device_type_)) {
943 // first searched
944 if (subgraph->nodes_.empty()) {
945 subgraph->device_ = static_cast<DeviceType>(node->device_type_);
946 } else {
947 // check pre_device_type equal to current device_type
948 if (subgraph->device_ != static_cast<DeviceType>(node->device_type_)) {
949 return;
950 }
951 }
952 subgraph->nodes_.insert(subgraph->nodes_.begin(), index);
953 node_list_[index] = nullptr;
954 } else {
955 subgraph->search_terminate_ = true;
956 }
957
958 // just deal with parallel target node
959 std::vector<uint32_t> input = node->input_indices_;
960
961 /* remove const node */
962 for (int i = static_cast<int>(input.size()) - 1; i >= 0; i--) {
963 if (tensors_[input[i]].type_ == CONST) {
964 VectorErase(&input, input[i]);
965 }
966 }
967
968 // search to graph to graph input , terminate it.
969 if (std::any_of(input.begin(), input.end(),
970 [&](uint32_t input_index) { return tensors_[input_index].type_ == INPUT; })) {
971 subgraph->search_terminate_ = true;
972 return;
973 }
974
975 // search for next nodes
976 for (uint32_t next : input) {
977 auto next_nodes = tensors_[next].out_nodes_;
978 for (uint32_t next_node : next_nodes) {
979 InsertParallelNode(next_node, subgraph);
980 }
981 }
982 }
CheckSubHeadEnd(Subgraph * sub)983 void SearchSubGraph::CheckSubHeadEnd(Subgraph *sub) {
984 /* head-end node may error after subgraph fusion */
985 /* sub head node check */
986 std::vector<uint32_t> delete_head;
987 for (uint32_t head_node : sub->heads_) {
988 if (std::find(sub->nodes_.begin(), sub->nodes_.end(), head_node) == sub->nodes_.end()) {
989 delete_head.push_back(head_node);
990 continue;
991 }
992 LiteGraph::Node *node = model_->graph_.all_nodes_.at(head_node);
993 std::vector<uint32_t> in_tensors = node->input_indices_;
994 std::vector<uint32_t> in_nodes;
995 for (uint32_t in_t : in_tensors) {
996 in_nodes.insert(in_nodes.begin(), tensors_.at(in_t).out_nodes_.begin(), tensors_.at(in_t).out_nodes_.end());
997 }
998
999 if (in_nodes.empty()) {
1000 continue;
1001 }
1002
1003 bool erase_head = true;
1004 for (uint32_t in_n : in_nodes) {
1005 if (std::find(sub->nodes_.begin(), sub->nodes_.end(), in_n) == sub->nodes_.end()) {
1006 erase_head = false;
1007 break;
1008 }
1009 }
1010 if (erase_head) {
1011 delete_head.push_back(head_node);
1012 }
1013 }
1014 for (uint32_t head : delete_head) {
1015 VectorErase(&sub->heads_, head);
1016 }
1017
1018 /* sub end node check */
1019 std::vector<uint32_t> delete_end;
1020 for (uint32_t end_node : sub->ends_) {
1021 if (std::find(sub->nodes_.begin(), sub->nodes_.end(), end_node) == sub->nodes_.end()) {
1022 delete_end.push_back(end_node);
1023 }
1024 }
1025 for (uint32_t end : delete_end) {
1026 VectorErase(&sub->ends_, end);
1027 }
1028 return;
1029 }
1030
ValidInParallel()1031 bool SearchSubGraph::ValidInParallel() {
1032 LiteGraph::Node *front_node = model_->graph_.all_nodes_.at(0);
1033 if (front_node->quant_type_ != schema::QuantType_QUANT_NONE) {
1034 return false;
1035 }
1036 if (major_dt_ == DT_NPU) {
1037 return false;
1038 }
1039 if (model_->graph_.sub_graphs_.size() > 1) {
1040 return false;
1041 }
1042 if (model_->GetSchemaVersion() != SCHEMA_VERSION::SCHEMA_CUR) {
1043 return false;
1044 }
1045 return true;
1046 }
1047
SubGraphSplit()1048 void SearchSubGraph::SubGraphSplit() {
1049 if (!ValidInParallel()) {
1050 return;
1051 }
1052
1053 UpdateOfflineParallelFlag();
1054 if (offline_parallel_enable_) {
1055 SubGraphSplitByOffLineParallel();
1056 } else {
1057 SubGraphSplitByOutput();
1058 SubGraphSplitByMiddle();
1059 }
1060 return;
1061 }
1062 } // namespace mindspore::lite
1063