• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2021 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
#include "src/delegate/npu/pass/npu_fusion_pass.h"
#include <algorithm>
#include <vector>
#include "src/delegate/npu/pass/npu_pass_utils.h"
#include "src/delegate/npu/npu_converter_utils.h"
#include "src/delegate/npu/op/concat_npu.h"
#include "src/delegate/npu/op/split_npu.h"
#include "src/delegate/npu/op/pad_npu.h"
#include "src/delegate/npu/op/strided_slice_npu.h"
25 
26 using mindspore::lite::RET_ERROR;
27 using mindspore::lite::RET_OK;
28 
namespace {
// Rank of the NHWC/NCHW tensors this pass rewires (N, H, W, C).
constexpr int kNumDims = 4;
// Minimum input-tensor count StridedSliceFusion needs (input, begin, end, stride).
constexpr int kNumInputSize = 4;
}  // namespace
33 
34 namespace mindspore {
CheckFusion(NPUOp * cur_op)35 bool CheckFusion(NPUOp *cur_op) {
36   if (cur_op->in_ops().empty() || cur_op->out_ops().empty()) {
37     return false;
38   }
39   auto pre_flag = std::all_of(cur_op->in_ops().begin(), cur_op->in_ops().end(), [](NPUOp *in_op) {
40     return NPUPassUtils::IsNchw2Nhwc(in_op) && in_op->out_ops().size() == 1;
41   });
42   if (!pre_flag) {
43     return false;
44   }
45   auto post_flag = std::all_of(cur_op->out_ops().begin(), cur_op->out_ops().end(),
46                                [](NPUOp *out_op) { return NPUPassUtils::IsNhwc2Nchw(out_op); });
47   return post_flag;
48 }
49 
CheckFormatFusion(NPUOp * cur_op)50 bool CheckFormatFusion(NPUOp *cur_op) {
51   if (cur_op->out_ops().empty()) {
52     return false;
53   }
54   if (NPUPassUtils::IsNhwc2Nchw(cur_op)) {
55     return std::all_of(cur_op->out_ops().begin(), cur_op->out_ops().end(),
56                        [](NPUOp *cur_op) { return NPUPassUtils::IsNchw2Nhwc(cur_op); });
57   }
58   if (NPUPassUtils::IsNchw2Nhwc(cur_op)) {
59     return std::all_of(cur_op->out_ops().begin(), cur_op->out_ops().end(),
60                        [](NPUOp *cur_op) { return NPUPassUtils::IsNhwc2Nchw(cur_op); });
61   }
62   return false;
63 }
64 
RemoveAndFreeOp(NPUOp * cur_op)65 void NPUFusionPass::RemoveAndFreeOp(NPUOp *cur_op) {
66   auto itr = find(all_ops_->begin(), all_ops_->end(), cur_op);
67   if (itr != all_ops_->end()) {
68     all_ops_->erase(itr);
69   }
70   delete cur_op;
71 }
72 
UpdatePreOps(NPUOp * cur_op)73 int NPUFusionPass::UpdatePreOps(NPUOp *cur_op) {
74   for (auto in_op : cur_op->in_ops()) {
75     // graph in op
76     if (in_op->in_ops().empty()) {
77       continue;
78     }
79     auto pre_op = in_op->in_ops()[0];
80 
81     auto pre_out_ops = pre_op->out_ops();
82     for (size_t i = 0; i < pre_out_ops.size(); i++) {
83       if (pre_out_ops[i] == in_op) {
84         pre_out_ops[i] = cur_op;
85         break;
86       }
87     }
88     pre_op->set_out_ops(pre_out_ops);
89 
90     auto cur_in_ops = cur_op->in_ops();
91     for (size_t i = 0; i < cur_in_ops.size(); i++) {
92       if (cur_in_ops[i] == in_op) {
93         cur_in_ops[i] = pre_op;
94         break;
95       }
96     }
97     cur_op->set_in_ops(cur_in_ops);
98     RemoveAndFreeOp(in_op);
99   }
100   return RET_OK;
101 }
102 
// Detaches each (transpose) output op of cur_op: relinks cur_op directly to the
// grandchild ops and frees the removed output ops.
int NPUFusionPass::UpdatePostOps(NPUOp *cur_op) {
  // Mutate a copy of the out-op list; the loop iterates the original edges.
  auto cur_out_ops = cur_op->out_ops();
  for (auto out_op : cur_op->out_ops()) {
    // graph out op: the transpose produced a graph output, so cur_op simply
    // drops this edge instead of being relinked to a successor.
    if (out_op->out_ops().empty()) {
      cur_out_ops.erase(find(cur_out_ops.begin(), cur_out_ops.end(), out_op));
    } else {
      // NOTE(review): only the transpose's first successor is relinked;
      // presumably a fused transpose has exactly one out op — confirm.
      auto post_op = out_op->out_ops()[0];
      // Replace out_op with cur_op in the successor's in-op list.
      auto post_in_ops = post_op->in_ops();
      for (size_t i = 0; i < post_in_ops.size(); i++) {
        if (post_in_ops[i] == out_op) {
          post_in_ops[i] = cur_op;
          break;
        }
      }
      post_op->set_in_ops(post_in_ops);

      // Replace the edge cur_op->out_op with cur_op->post_op.
      for (size_t i = 0; i < cur_out_ops.size(); i++) {
        if (cur_out_ops[i] == out_op) {
          cur_out_ops[i] = post_op;
          break;
        }
      }
    }
    RemoveAndFreeOp(out_op);
  }
  cur_op->set_out_ops(cur_out_ops);
  return RET_OK;
}
132 
// Rewires cur_op's input tensors after its input transposes are fused away:
// each tensor produced by an input transpose is replaced by the transpose's own
// input tensor, then constant inputs are restored at their original indices.
int UpdatePreTensors(NPUOp *cur_op) {
  auto tensors_vec = NPUPassUtils::GetNonConstInputs(cur_op);
  for (auto in_op : cur_op->in_ops()) {
    if (in_op->inputs().empty() || in_op->outputs().empty() || in_op->in_ops().empty()) {
      MS_LOG(ERROR) << "in_tensors/out_tensors/in_ops is empty.";
      return RET_ERROR;
    }
    // NOTE(review): cur_tensor stays default-constructed if the producer does
    // not list in_tensor among its outputs — confirm that cannot happen here.
    mindspore::MSTensor cur_tensor;
    auto in_tensor = in_op->inputs()[0];
    auto out_tensor = in_op->outputs()[0];
    auto pre_op = in_op->in_ops()[0];
    // Find the producer's tensor that feeds the transpose being removed.
    for (size_t i = 0; i < pre_op->outputs().size(); i++) {
      if (pre_op->outputs()[i] == in_tensor) {
        cur_tensor = pre_op->outputs()[i];
      }
    }
    // Substitute the transpose's output tensor with the producer's tensor.
    for (size_t i = 0; i < tensors_vec.size(); i++) {
      if (tensors_vec[i] == out_tensor) {
        tensors_vec[i] = cur_tensor;
      }
    }
  }
  // add constant inputs back
  if (nodes2const_index.find(cur_op->type()) != nodes2const_index.end()) {
    tensors_vec.resize(cur_op->inputs().size());
    auto const_index = nodes2const_index[cur_op->type()];
    for (auto index : const_index) {
      // Skip registered const indices that exceed this op's actual input count.
      if (index >= cur_op->inputs().size()) {
        continue;
      }
      tensors_vec[index] = cur_op->inputs()[index];
    }
  }
  cur_op->set_inputs(tensors_vec);
  return RET_OK;
}
169 
NodeWithNhwc2nchw2nhwcOutput(NPUOp * cur_op)170 bool NodeWithNhwc2nchw2nhwcOutput(NPUOp *cur_op) {
171   auto out_ops = cur_op->out_ops();
172   if (out_ops.empty()) {
173     return false;
174   }
175   bool all_out_ops_transpose = std::all_of(out_ops.begin(), out_ops.end(), [](NPUOp *op) {
176     return op->type() == schema::PrimitiveType_Transpose && op->out_ops().size() == 1 &&
177            op->out_ops()[0]->type() == schema::PrimitiveType_Transpose && op->out_ops()[0]->out_ops().empty();
178   });
179   return all_out_ops_transpose;
180 }
181 
UpdatePostTensors(NPUOp * cur_op)182 int UpdatePostTensors(NPUOp *cur_op) {
183   auto tensor = cur_op->outputs()[0];
184 
185   // in case: node->nh2nc->nc2nh(graph output) --->>> node->nc2nh, node out_tensor should be put to nc2nh out tensors
186   auto out_ops = cur_op->out_ops();
187   if (NodeWithNhwc2nchw2nhwcOutput(cur_op)) {
188     std::vector<MSTensor> outputs;
189     for (auto i = 0; i < out_ops.size(); ++i) {
190       auto ori_out_tensor = cur_op->outputs()[i];
191       auto nc_tensor = out_ops[i]->outputs()[0];
192       outputs.push_back(nc_tensor);
193       auto post_post_op = out_ops[i]->out_ops()[0];
194       post_post_op->set_inputs({nc_tensor});
195       post_post_op->set_outputs({ori_out_tensor});
196     }
197     cur_op->set_outputs(outputs);
198     return RET_OK;
199   }
200 
201   auto nhwc_shape = tensor.Shape();
202   if (nhwc_shape.size() < kNumDims) {
203     MS_LOG(ERROR) << "nhwc_shape < " << kNumDims;
204     return RET_ERROR;
205   }
206   tensor.SetShape({nhwc_shape[NHWC_N], nhwc_shape[NHWC_C], nhwc_shape[NHWC_H], nhwc_shape[NHWC_W]});
207   for (auto out_op : cur_op->out_ops()) {
208     auto out_tensor = out_op->outputs()[0];
209     if (out_op->out_ops().empty()) {
210       cur_op->set_outputs({out_op->outputs()[0]});
211     }
212     for (auto post_op : out_op->out_ops()) {
213       auto tensors_vec = post_op->inputs();
214       for (int i = 0; i < tensors_vec.size(); i++) {
215         if (tensors_vec[i] == out_tensor) {
216           tensors_vec[i] = tensor;
217         }
218       }
219       post_op->set_inputs(tensors_vec);
220     }
221   }
222   return RET_OK;
223 }
224 
UpdateOp(NPUOp * cur_op)225 int NPUFusionPass::UpdateOp(NPUOp *cur_op) {
226   if (cur_op == nullptr) {
227     MS_LOG(ERROR) << "kernel is nullptr.";
228     return RET_ERROR;
229   }
230   auto ret = UpdatePreTensors(cur_op);
231   if (ret != RET_OK) {
232     MS_LOG(ERROR) << "UpdatePreTensors failed.";
233     return RET_ERROR;
234   }
235   ret = UpdatePostTensors(cur_op);
236   if (ret != RET_OK) {
237     MS_LOG(ERROR) << "UpdatePostTensors failed.";
238     return RET_ERROR;
239   }
240   ret = UpdatePreOps(cur_op);
241   if (ret != RET_OK) {
242     MS_LOG(ERROR) << "UpdatePreOps failed.";
243     return RET_ERROR;
244   }
245   ret = UpdatePostOps(cur_op);
246   if (ret != RET_OK) {
247     MS_LOG(ERROR) << "UpdatePostOps failed.";
248     return RET_ERROR;
249   }
250   return RET_OK;
251 }
252 
CommonFusion(NPUOp * cur_op)253 int NPUFusionPass::CommonFusion(NPUOp *cur_op) {
254   if (cur_op == nullptr) {
255     return RET_ERROR;
256   }
257   auto ret = UpdateOp(cur_op);
258   if (ret != RET_OK) {
259     MS_LOG(ERROR) << "UpdateOp failed.";
260     return RET_ERROR;
261   }
262   return RET_OK;
263 }
264 
ConcatFusion(NPUOp * cur_op)265 int NPUFusionPass::ConcatFusion(NPUOp *cur_op) {
266   if (cur_op == nullptr) {
267     return RET_ERROR;
268   }
269   int ret = UpdateOp(cur_op);
270   if (ret != RET_OK) {
271     MS_LOG(ERROR) << "UpdateOp failed.";
272     return ret;
273   }
274   if (cur_op->type() != schema::PrimitiveType_Concat) {
275     return RET_ERROR;
276   }
277   auto concat_op = static_cast<ConcatNPUOp *>(cur_op);
278   ret = concat_op->HandleAxis();
279   if (ret != RET_OK) {
280     MS_LOG(ERROR) << "HandleAxis failed.";
281     return ret;
282   }
283   return RET_OK;
284 }
285 
SplitFusion(NPUOp * cur_op)286 int NPUFusionPass::SplitFusion(NPUOp *cur_op) {
287   if (cur_op == nullptr) {
288     return RET_ERROR;
289   }
290   int ret = UpdateOp(cur_op);
291   if (ret != RET_OK) {
292     MS_LOG(ERROR) << "UpdateOp failed.";
293     return RET_ERROR;
294   }
295   if (cur_op->type() != schema::PrimitiveType_Split) {
296     return RET_ERROR;
297   }
298   auto split_op = static_cast<SplitNPUOp *>(cur_op);
299   ret = split_op->HandleAxis();
300   if (ret != RET_OK) {
301     MS_LOG(ERROR) << "HandleAxis failed.";
302     return ret;
303   }
304   return RET_OK;
305 }
306 
PadFusion(NPUOp * cur_op)307 int NPUFusionPass::PadFusion(NPUOp *cur_op) {
308   if (cur_op == nullptr) {
309     return RET_ERROR;
310   }
311   int ret = UpdateOp(cur_op);
312   if (ret != RET_OK) {
313     MS_LOG(ERROR) << "UpdateOp failed.";
314     return ret;
315   }
316   if (cur_op->type() != schema::PrimitiveType_PadFusion) {
317     return RET_ERROR;
318   }
319   auto pad_op = static_cast<PadNPUOp *>(cur_op);
320   ret = pad_op->HandleAxis();
321   if (ret != RET_OK) {
322     MS_LOG(ERROR) << "HandleAxis failed.";
323     return ret;
324   }
325   return RET_OK;
326 }
327 
StridedSliceFusion(NPUOp * cur_op)328 int NPUFusionPass::StridedSliceFusion(NPUOp *cur_op) {
329   // basic requirement: input is nhwc 4d
330   if (cur_op == nullptr) {
331     return RET_ERROR;
332   }
333   int ret = UpdateOp(cur_op);
334   if (ret != RET_OK) {
335     MS_LOG(ERROR) << "UpdateOp failed.";
336     return ret;
337   }
338   if (cur_op->inputs().size() < kNumInputSize) {
339     MS_LOG(ERROR) << "in tensors size < " << kNumInputSize;
340     return RET_ERROR;
341   }
342   if (cur_op->type() != schema::PrimitiveType_StridedSlice) {
343     return RET_ERROR;
344   }
345   auto begin_tensor = cur_op->inputs().at(BEGIN_INDEX);
346   int *begin = reinterpret_cast<int *>(begin_tensor.MutableData());
347   MS_ASSERT(begin);
348   (void)NPUPassUtils::AssistDataNHWC2NCHW(begin, 1);
349   auto end_tensor = cur_op->inputs().at(END_INDEX);
350   int *end = reinterpret_cast<int *>(end_tensor.MutableData());
351   MS_ASSERT(end);
352   NPUPassUtils::AssistDataNHWC2NCHW(end, 1);
353   auto stride_tensor = cur_op->inputs().at(STRIDE_INDEX);
354   if (cur_op->inputs().size() == ONNX_INPUT_SIZE) {
355     stride_tensor = cur_op->inputs().at(ONNX_STRIDE_INDEX);
356   }
357   int *stride = reinterpret_cast<int *>(stride_tensor.MutableData());
358   MS_ASSERT(stride);
359   NPUPassUtils::AssistDataNHWC2NCHW(stride, 1);
360 
361   auto stride_slice_op = static_cast<StridedSliceNPUOp *>(cur_op);
362   ret = stride_slice_op->HandleAxis();
363   if (ret != RET_OK) {
364     MS_LOG(ERROR) << "HandleAxis failed.";
365     return ret;
366   }
367   return RET_OK;
368 }
369 
// Removes a pair of mutually-inverse transposes: cur_op (a transpose) and each
// of its out ops (the inverse transpose), reconnecting cur_op's producer
// directly to the ops that consumed the inverse transposes.
int NPUFusionPass::FormatFusion(NPUOp *cur_op) {
  if (cur_op == nullptr) {
    return RET_ERROR;
  }
  // cur_op may itself consume a graph input, in which case there is no
  // producer to relink.
  auto is_input_op = cur_op->in_ops().empty();
  NPUOp *pre_op = nullptr;
  if (!is_input_op) {
    pre_op = cur_op->in_ops()[0];
  }
  auto in_tensor = cur_op->inputs()[0];
  // Consumers of the removed transposes; re-attached to pre_op at the end.
  std::vector<NPUOp *> pre_insert_ops;
  for (const auto &trans_op : cur_op->out_ops()) {
    if (trans_op->out_ops().empty() && !is_input_op) {
      // cur_op is a trans cur_op, it's input cur_op num and input tensor num must be 1
      cur_op->in_ops()[0]->set_outputs({trans_op->outputs()[0]});
      // in fp16 mode, tensor data type fp16 need to be changed back.
      // NOTE(review): `tensor` is a copy of the MSTensor handle; this relies on
      // MSTensor sharing its underlying implementation so SetDataType takes
      // effect on the producer's output — confirm.
      auto tensor = cur_op->in_ops()[0]->outputs()[0];
      if (tensor.DataType() == DataType::kNumberTypeFloat16) {
        tensor.SetDataType(DataType::kNumberTypeFloat32);
      }
    }
    for (const auto &post_op : trans_op->out_ops()) {
      // update tensor
      auto tensors_vec = post_op->inputs();
      for (size_t i = 0; i < tensors_vec.size(); i++) {
        if (tensors_vec[i] == trans_op->outputs()[0]) {
          // The consumer now reads cur_op's input tensor directly.
          tensors_vec[i] = in_tensor;
          break;
        }
      }
      post_op->set_inputs(tensors_vec);

      // update op
      auto post_in_ops = post_op->in_ops();
      for (size_t i = 0; i < post_in_ops.size(); i++) {
        if (post_in_ops[i] == trans_op) {
          if (is_input_op) {
            // No producer exists: the consumer simply loses this in edge.
            post_in_ops.erase(post_in_ops.begin() + i);
          } else {
            post_in_ops[i] = pre_op;
          }
          break;
        }
      }
      post_op->set_in_ops(post_in_ops);
      pre_insert_ops.push_back(post_op);
    }
    RemoveAndFreeOp(trans_op);
  }
  if (!is_input_op) {
    // Replace cur_op in pre_op's out list with the gathered consumers, and fix
    // sibling consumers that still reference cur_op's input tensor.
    auto pre_out_ops = pre_op->out_ops();
    size_t cur_op_index = 0;
    for (size_t index = 0; index < pre_out_ops.size(); index++) {
      if (pre_out_ops[index] == cur_op) {
        pre_out_ops.erase(pre_out_ops.begin() + index);
        cur_op_index = index;
      } else {
        auto tensors_vec = pre_out_ops[index]->inputs();
        for (size_t i = 0; i < tensors_vec.size(); i++) {
          if (tensors_vec[i] == in_tensor) {
            tensors_vec[i] = pre_op->outputs()[0];
            break;
          }
        }
        pre_out_ops[index]->set_inputs(tensors_vec);
      }
    }
    // Splice the consumers in where cur_op used to sit, preserving order.
    pre_out_ops.insert(pre_out_ops.begin() + cur_op_index, pre_insert_ops.begin(), pre_insert_ops.end());
    pre_op->set_out_ops(pre_out_ops);
  }
  RemoveAndFreeOp(cur_op);
  return RET_OK;
}
443 
Run(NPUGraph * subgraph)444 int NPUFusionPass::Run(NPUGraph *subgraph) {
445   all_ops_ = subgraph->GetOps();
446   for (size_t i = 0; i < all_ops_->size(); i++) {
447     auto cur_op = (*all_ops_)[i];
448     auto ret = RET_OK;
449     if (CheckFusion(cur_op)) {
450       switch (cur_op->type()) {
451         case schema::PrimitiveType_Split:
452           i -= cur_op->in_ops().size();
453           ret = SplitFusion(cur_op);
454           continue;
455         case schema::PrimitiveType_Concat:
456           i -= cur_op->in_ops().size();
457           ret = ConcatFusion(cur_op);
458           continue;
459         case schema::PrimitiveType_PadFusion:
460           i -= cur_op->in_ops().size();
461           ret = PadFusion(cur_op);
462           continue;
463         case schema::PrimitiveType_StridedSlice:
464           i -= cur_op->in_ops().size();
465           ret = StridedSliceFusion(cur_op);
466           continue;
467         case schema::PrimitiveType_AddFusion:
468         case schema::PrimitiveType_MulFusion:
469         case schema::PrimitiveType_DivFusion:
470         case schema::PrimitiveType_Activation:
471         case schema::PrimitiveType_Eltwise:
472           i -= cur_op->in_ops().size();
473           ret = CommonFusion(cur_op);
474           continue;
475         default:
476           continue;
477       }
478     }
479     if (ret != RET_OK) {
480       MS_LOG(ERROR) << "Fusion failed.";
481       return RET_ERROR;
482     }
483   }
484   for (size_t i = 0; i < all_ops_->size(); ++i) {
485     auto cur_op = (*all_ops_)[i];
486     if (CheckFormatFusion(cur_op)) {
487       i--;
488       auto ret = FormatFusion(cur_op);
489       if (ret != RET_OK) {
490         MS_LOG(ERROR) << "FormatFusion failed.";
491         return RET_ERROR;
492       }
493     }
494   }
495   return RET_OK;
496 }
497 }  // namespace mindspore
498