1 /**
2 * Copyright 2021 Huawei Technologies Co., Ltd
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 #include "src/delegate/npu/pass/npu_fusion_pass.h"
18 #include <vector>
19 #include "src/delegate/npu/pass/npu_pass_utils.h"
20 #include "src/delegate/npu/npu_converter_utils.h"
21 #include "src/delegate/npu/op/concat_npu.h"
22 #include "src/delegate/npu/op/split_npu.h"
23 #include "src/delegate/npu/op/pad_npu.h"
24 #include "src/delegate/npu/op/strided_slice_npu.h"
25
26 using mindspore::lite::RET_ERROR;
27 using mindspore::lite::RET_OK;
28
namespace {
// Expected tensor rank for NHWC/NCHW format conversion (N, H, W, C).
constexpr int kNumDims = 4;
// Minimum input-tensor count for StridedSlice fusion: input, begin, end, stride.
constexpr int kNumInputSize = 4;
}  // namespace
33
34 namespace mindspore {
CheckFusion(NPUOp * cur_op)35 bool CheckFusion(NPUOp *cur_op) {
36 if (cur_op->in_ops().empty() || cur_op->out_ops().empty()) {
37 return false;
38 }
39 auto pre_flag = std::all_of(cur_op->in_ops().begin(), cur_op->in_ops().end(), [](NPUOp *in_op) {
40 return NPUPassUtils::IsNchw2Nhwc(in_op) && in_op->out_ops().size() == 1;
41 });
42 if (!pre_flag) {
43 return false;
44 }
45 auto post_flag = std::all_of(cur_op->out_ops().begin(), cur_op->out_ops().end(),
46 [](NPUOp *out_op) { return NPUPassUtils::IsNhwc2Nchw(out_op); });
47 return post_flag;
48 }
49
CheckFormatFusion(NPUOp * cur_op)50 bool CheckFormatFusion(NPUOp *cur_op) {
51 if (cur_op->out_ops().empty()) {
52 return false;
53 }
54 if (NPUPassUtils::IsNhwc2Nchw(cur_op)) {
55 return std::all_of(cur_op->out_ops().begin(), cur_op->out_ops().end(),
56 [](NPUOp *cur_op) { return NPUPassUtils::IsNchw2Nhwc(cur_op); });
57 }
58 if (NPUPassUtils::IsNchw2Nhwc(cur_op)) {
59 return std::all_of(cur_op->out_ops().begin(), cur_op->out_ops().end(),
60 [](NPUOp *cur_op) { return NPUPassUtils::IsNhwc2Nchw(cur_op); });
61 }
62 return false;
63 }
64
RemoveAndFreeOp(NPUOp * cur_op)65 void NPUFusionPass::RemoveAndFreeOp(NPUOp *cur_op) {
66 auto itr = find(all_ops_->begin(), all_ops_->end(), cur_op);
67 if (itr != all_ops_->end()) {
68 all_ops_->erase(itr);
69 }
70 delete cur_op;
71 }
72
// Bypasses each input transpose op of cur_op: relinks cur_op directly to the
// transpose's producer, then frees the transpose. Assumes each in_op is a
// single-producer transpose (established by CheckFusion before fusion runs).
int NPUFusionPass::UpdatePreOps(NPUOp *cur_op) {
  for (auto in_op : cur_op->in_ops()) {
    // graph in op
    if (in_op->in_ops().empty()) {
      continue;
    }
    auto pre_op = in_op->in_ops()[0];

    // In the producer's consumer list, replace the transpose with cur_op.
    // out_ops() returns a copy, so mutate it and write it back.
    auto pre_out_ops = pre_op->out_ops();
    for (size_t i = 0; i < pre_out_ops.size(); i++) {
      if (pre_out_ops[i] == in_op) {
        pre_out_ops[i] = cur_op;
        break;
      }
    }
    pre_op->set_out_ops(pre_out_ops);

    // In cur_op's producer list, replace the transpose with its producer.
    auto cur_in_ops = cur_op->in_ops();
    for (size_t i = 0; i < cur_in_ops.size(); i++) {
      if (cur_in_ops[i] == in_op) {
        cur_in_ops[i] = pre_op;
        break;
      }
    }
    cur_op->set_in_ops(cur_in_ops);
    // The transpose is now fully unlinked; remove it from the graph and free it.
    RemoveAndFreeOp(in_op);
  }
  return RET_OK;
}
102
// Bypasses each output transpose op of cur_op: relinks cur_op directly to the
// transpose's consumer (or drops the link entirely when the transpose is a
// graph output), then frees the transpose.
int NPUFusionPass::UpdatePostOps(NPUOp *cur_op) {
  // Work on a copy of the consumer list; the loop below iterates the original
  // (also a copy from out_ops()) so erasing from cur_out_ops is safe.
  auto cur_out_ops = cur_op->out_ops();
  for (auto out_op : cur_op->out_ops()) {
    // graph out op
    if (out_op->out_ops().empty()) {
      cur_out_ops.erase(find(cur_out_ops.begin(), cur_out_ops.end(), out_op));
    } else {
      // NOTE(review): only the first consumer of the transpose is relinked —
      // presumably a fused transpose has exactly one consumer; confirm.
      auto post_op = out_op->out_ops()[0];
      auto post_in_ops = post_op->in_ops();
      for (size_t i = 0; i < post_in_ops.size(); i++) {
        if (post_in_ops[i] == out_op) {
          post_in_ops[i] = cur_op;
          break;
        }
      }
      post_op->set_in_ops(post_in_ops);

      // In cur_op's consumer list, replace the transpose with its consumer.
      for (size_t i = 0; i < cur_out_ops.size(); i++) {
        if (cur_out_ops[i] == out_op) {
          cur_out_ops[i] = post_op;
          break;
        }
      }
    }
    // The transpose is now fully unlinked; remove it from the graph and free it.
    RemoveAndFreeOp(out_op);
  }
  cur_op->set_out_ops(cur_out_ops);
  return RET_OK;
}
132
// Rewires cur_op's input tensors to point at the tensors produced BEFORE the
// input transpose ops, in preparation for removing those transposes.
// Constant inputs (weights/params) are preserved at their original indices.
int UpdatePreTensors(NPUOp *cur_op) {
  auto tensors_vec = NPUPassUtils::GetNonConstInputs(cur_op);
  for (auto in_op : cur_op->in_ops()) {
    if (in_op->inputs().empty() || in_op->outputs().empty() || in_op->in_ops().empty()) {
      MS_LOG(ERROR) << "in_tensors/out_tensors/in_ops is empty.";
      return RET_ERROR;
    }
    // NOTE(review): if pre_op's outputs don't contain in_tensor, cur_tensor
    // stays default-constructed and is still written below — confirm callers
    // guarantee the producer/consumer tensor link is always present.
    mindspore::MSTensor cur_tensor;
    auto in_tensor = in_op->inputs()[0];
    auto out_tensor = in_op->outputs()[0];
    auto pre_op = in_op->in_ops()[0];
    // Find the tensor that feeds the transpose (the pre-transpose tensor).
    for (size_t i = 0; i < pre_op->outputs().size(); i++) {
      if (pre_op->outputs()[i] == in_tensor) {
        cur_tensor = pre_op->outputs()[i];
      }
    }
    // Replace the transpose's output with the pre-transpose tensor.
    for (size_t i = 0; i < tensors_vec.size(); i++) {
      if (tensors_vec[i] == out_tensor) {
        tensors_vec[i] = cur_tensor;
      }
    }
  }
  // add constant inputs back
  // nodes2const_index maps op types to the indices of their constant inputs
  // (declared elsewhere in the NPU pass module).
  if (nodes2const_index.find(cur_op->type()) != nodes2const_index.end()) {
    tensors_vec.resize(cur_op->inputs().size());
    auto const_index = nodes2const_index[cur_op->type()];
    for (auto index : const_index) {
      if (index >= cur_op->inputs().size()) {
        continue;
      }
      tensors_vec[index] = cur_op->inputs()[index];
    }
  }
  cur_op->set_inputs(tensors_vec);
  return RET_OK;
}
169
NodeWithNhwc2nchw2nhwcOutput(NPUOp * cur_op)170 bool NodeWithNhwc2nchw2nhwcOutput(NPUOp *cur_op) {
171 auto out_ops = cur_op->out_ops();
172 if (out_ops.empty()) {
173 return false;
174 }
175 bool all_out_ops_transpose = std::all_of(out_ops.begin(), out_ops.end(), [](NPUOp *op) {
176 return op->type() == schema::PrimitiveType_Transpose && op->out_ops().size() == 1 &&
177 op->out_ops()[0]->type() == schema::PrimitiveType_Transpose && op->out_ops()[0]->out_ops().empty();
178 });
179 return all_out_ops_transpose;
180 }
181
UpdatePostTensors(NPUOp * cur_op)182 int UpdatePostTensors(NPUOp *cur_op) {
183 auto tensor = cur_op->outputs()[0];
184
185 // in case: node->nh2nc->nc2nh(graph output) --->>> node->nc2nh, node out_tensor should be put to nc2nh out tensors
186 auto out_ops = cur_op->out_ops();
187 if (NodeWithNhwc2nchw2nhwcOutput(cur_op)) {
188 std::vector<MSTensor> outputs;
189 for (auto i = 0; i < out_ops.size(); ++i) {
190 auto ori_out_tensor = cur_op->outputs()[i];
191 auto nc_tensor = out_ops[i]->outputs()[0];
192 outputs.push_back(nc_tensor);
193 auto post_post_op = out_ops[i]->out_ops()[0];
194 post_post_op->set_inputs({nc_tensor});
195 post_post_op->set_outputs({ori_out_tensor});
196 }
197 cur_op->set_outputs(outputs);
198 return RET_OK;
199 }
200
201 auto nhwc_shape = tensor.Shape();
202 if (nhwc_shape.size() < kNumDims) {
203 MS_LOG(ERROR) << "nhwc_shape < " << kNumDims;
204 return RET_ERROR;
205 }
206 tensor.SetShape({nhwc_shape[NHWC_N], nhwc_shape[NHWC_C], nhwc_shape[NHWC_H], nhwc_shape[NHWC_W]});
207 for (auto out_op : cur_op->out_ops()) {
208 auto out_tensor = out_op->outputs()[0];
209 if (out_op->out_ops().empty()) {
210 cur_op->set_outputs({out_op->outputs()[0]});
211 }
212 for (auto post_op : out_op->out_ops()) {
213 auto tensors_vec = post_op->inputs();
214 for (int i = 0; i < tensors_vec.size(); i++) {
215 if (tensors_vec[i] == out_tensor) {
216 tensors_vec[i] = tensor;
217 }
218 }
219 post_op->set_inputs(tensors_vec);
220 }
221 }
222 return RET_OK;
223 }
224
UpdateOp(NPUOp * cur_op)225 int NPUFusionPass::UpdateOp(NPUOp *cur_op) {
226 if (cur_op == nullptr) {
227 MS_LOG(ERROR) << "kernel is nullptr.";
228 return RET_ERROR;
229 }
230 auto ret = UpdatePreTensors(cur_op);
231 if (ret != RET_OK) {
232 MS_LOG(ERROR) << "UpdatePreTensors failed.";
233 return RET_ERROR;
234 }
235 ret = UpdatePostTensors(cur_op);
236 if (ret != RET_OK) {
237 MS_LOG(ERROR) << "UpdatePostTensors failed.";
238 return RET_ERROR;
239 }
240 ret = UpdatePreOps(cur_op);
241 if (ret != RET_OK) {
242 MS_LOG(ERROR) << "UpdatePreOps failed.";
243 return RET_ERROR;
244 }
245 ret = UpdatePostOps(cur_op);
246 if (ret != RET_OK) {
247 MS_LOG(ERROR) << "UpdatePostOps failed.";
248 return RET_ERROR;
249 }
250 return RET_OK;
251 }
252
CommonFusion(NPUOp * cur_op)253 int NPUFusionPass::CommonFusion(NPUOp *cur_op) {
254 if (cur_op == nullptr) {
255 return RET_ERROR;
256 }
257 auto ret = UpdateOp(cur_op);
258 if (ret != RET_OK) {
259 MS_LOG(ERROR) << "UpdateOp failed.";
260 return RET_ERROR;
261 }
262 return RET_OK;
263 }
264
ConcatFusion(NPUOp * cur_op)265 int NPUFusionPass::ConcatFusion(NPUOp *cur_op) {
266 if (cur_op == nullptr) {
267 return RET_ERROR;
268 }
269 int ret = UpdateOp(cur_op);
270 if (ret != RET_OK) {
271 MS_LOG(ERROR) << "UpdateOp failed.";
272 return ret;
273 }
274 if (cur_op->type() != schema::PrimitiveType_Concat) {
275 return RET_ERROR;
276 }
277 auto concat_op = static_cast<ConcatNPUOp *>(cur_op);
278 ret = concat_op->HandleAxis();
279 if (ret != RET_OK) {
280 MS_LOG(ERROR) << "HandleAxis failed.";
281 return ret;
282 }
283 return RET_OK;
284 }
285
SplitFusion(NPUOp * cur_op)286 int NPUFusionPass::SplitFusion(NPUOp *cur_op) {
287 if (cur_op == nullptr) {
288 return RET_ERROR;
289 }
290 int ret = UpdateOp(cur_op);
291 if (ret != RET_OK) {
292 MS_LOG(ERROR) << "UpdateOp failed.";
293 return RET_ERROR;
294 }
295 if (cur_op->type() != schema::PrimitiveType_Split) {
296 return RET_ERROR;
297 }
298 auto split_op = static_cast<SplitNPUOp *>(cur_op);
299 ret = split_op->HandleAxis();
300 if (ret != RET_OK) {
301 MS_LOG(ERROR) << "HandleAxis failed.";
302 return ret;
303 }
304 return RET_OK;
305 }
306
PadFusion(NPUOp * cur_op)307 int NPUFusionPass::PadFusion(NPUOp *cur_op) {
308 if (cur_op == nullptr) {
309 return RET_ERROR;
310 }
311 int ret = UpdateOp(cur_op);
312 if (ret != RET_OK) {
313 MS_LOG(ERROR) << "UpdateOp failed.";
314 return ret;
315 }
316 if (cur_op->type() != schema::PrimitiveType_PadFusion) {
317 return RET_ERROR;
318 }
319 auto pad_op = static_cast<PadNPUOp *>(cur_op);
320 ret = pad_op->HandleAxis();
321 if (ret != RET_OK) {
322 MS_LOG(ERROR) << "HandleAxis failed.";
323 return ret;
324 }
325 return RET_OK;
326 }
327
StridedSliceFusion(NPUOp * cur_op)328 int NPUFusionPass::StridedSliceFusion(NPUOp *cur_op) {
329 // basic requirement: input is nhwc 4d
330 if (cur_op == nullptr) {
331 return RET_ERROR;
332 }
333 int ret = UpdateOp(cur_op);
334 if (ret != RET_OK) {
335 MS_LOG(ERROR) << "UpdateOp failed.";
336 return ret;
337 }
338 if (cur_op->inputs().size() < kNumInputSize) {
339 MS_LOG(ERROR) << "in tensors size < " << kNumInputSize;
340 return RET_ERROR;
341 }
342 if (cur_op->type() != schema::PrimitiveType_StridedSlice) {
343 return RET_ERROR;
344 }
345 auto begin_tensor = cur_op->inputs().at(BEGIN_INDEX);
346 int *begin = reinterpret_cast<int *>(begin_tensor.MutableData());
347 MS_ASSERT(begin);
348 (void)NPUPassUtils::AssistDataNHWC2NCHW(begin, 1);
349 auto end_tensor = cur_op->inputs().at(END_INDEX);
350 int *end = reinterpret_cast<int *>(end_tensor.MutableData());
351 MS_ASSERT(end);
352 NPUPassUtils::AssistDataNHWC2NCHW(end, 1);
353 auto stride_tensor = cur_op->inputs().at(STRIDE_INDEX);
354 if (cur_op->inputs().size() == ONNX_INPUT_SIZE) {
355 stride_tensor = cur_op->inputs().at(ONNX_STRIDE_INDEX);
356 }
357 int *stride = reinterpret_cast<int *>(stride_tensor.MutableData());
358 MS_ASSERT(stride);
359 NPUPassUtils::AssistDataNHWC2NCHW(stride, 1);
360
361 auto stride_slice_op = static_cast<StridedSliceNPUOp *>(cur_op);
362 ret = stride_slice_op->HandleAxis();
363 if (ret != RET_OK) {
364 MS_LOG(ERROR) << "HandleAxis failed.";
365 return ret;
366 }
367 return RET_OK;
368 }
369
// Removes a transpose op (cur_op) together with its inverse-transpose
// consumers: consumers of the inverse transposes are relinked directly to
// cur_op's producer, tensors are rewired accordingly, and both transpose
// layers are freed. CheckFormatFusion must hold for cur_op before calling.
int NPUFusionPass::FormatFusion(NPUOp *cur_op) {
  if (cur_op == nullptr) {
    return RET_ERROR;
  }
  auto is_input_op = cur_op->in_ops().empty();
  NPUOp *pre_op = nullptr;
  if (!is_input_op) {
    pre_op = cur_op->in_ops()[0];
  }
  auto in_tensor = cur_op->inputs()[0];
  // Consumers of the removed transposes; they become consumers of pre_op.
  std::vector<NPUOp *> pre_insert_ops;
  for (const auto &trans_op : cur_op->out_ops()) {
    if (trans_op->out_ops().empty() && !is_input_op) {
      // cur_op is a trans cur_op, it's input cur_op num and input tensor num must be 1
      cur_op->in_ops()[0]->set_outputs({trans_op->outputs()[0]});
      // in fp16 mode, tensor data type fp16 need to be changed back.
      // NOTE(review): `tensor` is a local copy; presumably MSTensor is a
      // handle whose SetDataType mutates the shared tensor — confirm.
      auto tensor = cur_op->in_ops()[0]->outputs()[0];
      if (tensor.DataType() == DataType::kNumberTypeFloat16) {
        tensor.SetDataType(DataType::kNumberTypeFloat32);
      }
    }
    for (const auto &post_op : trans_op->out_ops()) {
      // update tensor
      auto tensors_vec = post_op->inputs();
      for (size_t i = 0; i < tensors_vec.size(); i++) {
        if (tensors_vec[i] == trans_op->outputs()[0]) {
          tensors_vec[i] = in_tensor;
          break;
        }
      }
      post_op->set_inputs(tensors_vec);

      // update op
      // Replace trans_op with pre_op in post_op's producer list; when cur_op
      // is a graph input there is no producer, so the entry is dropped.
      auto post_in_ops = post_op->in_ops();
      for (size_t i = 0; i < post_in_ops.size(); i++) {
        if (post_in_ops[i] == trans_op) {
          if (is_input_op) {
            post_in_ops.erase(post_in_ops.begin() + i);
          } else {
            post_in_ops[i] = pre_op;
          }
          break;
        }
      }
      post_op->set_in_ops(post_in_ops);
      pre_insert_ops.push_back(post_op);
    }
    // The inverse transpose is fully unlinked; free it.
    RemoveAndFreeOp(trans_op);
  }
  if (!is_input_op) {
    // Remove cur_op from its producer's consumer list and splice in the new
    // consumers at the same position to preserve topological order.
    auto pre_out_ops = pre_op->out_ops();
    size_t cur_op_index = 0;
    for (size_t index = 0; index < pre_out_ops.size(); index++) {
      if (pre_out_ops[index] == cur_op) {
        pre_out_ops.erase(pre_out_ops.begin() + index);
        cur_op_index = index;
      } else {
        // Sibling consumers of pre_op that read cur_op's input tensor must be
        // redirected to pre_op's own output tensor.
        auto tensors_vec = pre_out_ops[index]->inputs();
        for (size_t i = 0; i < tensors_vec.size(); i++) {
          if (tensors_vec[i] == in_tensor) {
            tensors_vec[i] = pre_op->outputs()[0];
            break;
          }
        }
        pre_out_ops[index]->set_inputs(tensors_vec);
      }
    }
    pre_out_ops.insert(pre_out_ops.begin() + cur_op_index, pre_insert_ops.begin(), pre_insert_ops.end());
    pre_op->set_out_ops(pre_out_ops);
  }
  RemoveAndFreeOp(cur_op);
  return RET_OK;
}
443
Run(NPUGraph * subgraph)444 int NPUFusionPass::Run(NPUGraph *subgraph) {
445 all_ops_ = subgraph->GetOps();
446 for (size_t i = 0; i < all_ops_->size(); i++) {
447 auto cur_op = (*all_ops_)[i];
448 auto ret = RET_OK;
449 if (CheckFusion(cur_op)) {
450 switch (cur_op->type()) {
451 case schema::PrimitiveType_Split:
452 i -= cur_op->in_ops().size();
453 ret = SplitFusion(cur_op);
454 continue;
455 case schema::PrimitiveType_Concat:
456 i -= cur_op->in_ops().size();
457 ret = ConcatFusion(cur_op);
458 continue;
459 case schema::PrimitiveType_PadFusion:
460 i -= cur_op->in_ops().size();
461 ret = PadFusion(cur_op);
462 continue;
463 case schema::PrimitiveType_StridedSlice:
464 i -= cur_op->in_ops().size();
465 ret = StridedSliceFusion(cur_op);
466 continue;
467 case schema::PrimitiveType_AddFusion:
468 case schema::PrimitiveType_MulFusion:
469 case schema::PrimitiveType_DivFusion:
470 case schema::PrimitiveType_Activation:
471 case schema::PrimitiveType_Eltwise:
472 i -= cur_op->in_ops().size();
473 ret = CommonFusion(cur_op);
474 continue;
475 default:
476 continue;
477 }
478 }
479 if (ret != RET_OK) {
480 MS_LOG(ERROR) << "Fusion failed.";
481 return RET_ERROR;
482 }
483 }
484 for (size_t i = 0; i < all_ops_->size(); ++i) {
485 auto cur_op = (*all_ops_)[i];
486 if (CheckFormatFusion(cur_op)) {
487 i--;
488 auto ret = FormatFusion(cur_op);
489 if (ret != RET_OK) {
490 MS_LOG(ERROR) << "FormatFusion failed.";
491 return RET_ERROR;
492 }
493 }
494 }
495 return RET_OK;
496 }
497 } // namespace mindspore
498