/**
 * Copyright 2022 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "src/litert/pass/format_pass/transpose_strategy.h"
#include "nnacl/op_base.h"
#include "nnacl/arg_min_max_parameter.h"
#include "nnacl/concat_parameter.h"
#include "nnacl/crop_parameter.h"
#include "nnacl/softmax_parameter.h"
#include "nnacl/split_parameter.h"
#include "nnacl/squeeze_parameter.h"
#include "nnacl/stack_parameter.h"
#include "nnacl/unsqueeze_parameter.h"
#include "nnacl/unstack_parameter.h"
#include "nnacl/slice_parameter.h"
#include "nnacl/strided_slice_parameter.h"

namespace mindspore::lite::pass {
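// Binary arithmetic kernels whose input shapes must be validated (see CheckInTensorsShape below)
// before a fusion is allowed to change their runtime format.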
static const std::set<schema::PrimitiveType> arithmetic_kernel_lists = {
  schema::PrimitiveType_AddFusion,    schema::PrimitiveType_AddN,
  schema::PrimitiveType_DivFusion,    schema::PrimitiveType_Eltwise,
  schema::PrimitiveType_Equal,        schema::PrimitiveType_FloorDiv,
  schema::PrimitiveType_FloorMod,     schema::PrimitiveType_Greater,
  schema::PrimitiveType_GreaterEqual, schema::PrimitiveType_Less,
  schema::PrimitiveType_LessEqual,    schema::PrimitiveType_LogicalAnd,
  schema::PrimitiveType_LogicalOr,    schema::PrimitiveType_Maximum,
  schema::PrimitiveType_Minimum,      schema::PrimitiveType_Mod,
  schema::PrimitiveType_MulFusion,    schema::PrimitiveType_NotEqual,
  schema::PrimitiveType_RealDiv,      schema::PrimitiveType_SquaredDifference,
  schema::PrimitiveType_SubFusion,
};

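// Counts how many of the given neighboring kernels are Transpose kernels sharing one permutation.
// The first transpose found initializes *trans_info; transposes with a different permutation are not counted.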
size_t TransposeStrategy::GetTransCount(const std::vector<kernel::KernelExec *> &kernels, TransInfoPair *trans_info) {
  size_t count = 0;
  for (const auto &in_kernel : kernels) {
    TransInfoPair tmp_trans;
    if (GetTransposeInfo(in_kernel, &tmp_trans) != RET_OK) {
      continue;
    }
    if (IsNoneTranspose(*trans_info)) {
      *trans_info = tmp_trans;
      count++;
    } else if (IsSameTranspose(*trans_info, tmp_trans)) {
      count++;
    } else {
      continue;
    }
  }
  return count;
}

bool CheckInTensorsShape(const kernel::KernelExec *kernel, const Format &runtime_format) {
  // If the fusion is valid, the kernel will be executed in runtime_format.
  // Only check the input tensors of arithmetic (two-input) kernels.
  // If broadcasting were supported for all formats, this function could be deleted.
  // e.g.: tensor 1 shape (1, 128, 24, 24), tensor 2 shape (1, 128, 1, 1): the NC4HW4 format is not supported now.
  if (arithmetic_kernel_lists.find(kernel::SchemaType(kernel->type())) == arithmetic_kernel_lists.end()) {
    return true;
  }
  for (const auto &in_tensor : kernel->in_tensors()) {
    const auto &in_shape = in_tensor->shape();
    if (std::any_of(in_shape.begin(), in_shape.end(), [](const int &dim) { return dim == -1; })) {
      return false;
    }
  }
  const auto &in0_shape = kernel->in_tensors().at(0)->shape();
  if (runtime_format == NHWC || runtime_format == NCHW) {
    // For NCHW or NHWC format, all inputs must have the same rank (shape size).
    if (std::any_of(kernel->in_tensors().begin(), kernel->in_tensors().end(),
                    [&in0_shape](const Tensor *in_tensor) { return in_tensor->shape().size() != in0_shape.size(); })) {
      return false;
    }
  } else {
    // For other formats (NCXHWX), all input shapes must be identical.
    if (std::any_of(kernel->in_tensors().begin(), kernel->in_tensors().end(),
                    [&in0_shape](const Tensor *in_tensor) { return in_tensor->shape() != in0_shape; })) {
      return false;
    }
  }
  return true;
}

namespace {
using TransAxisFunc = std::function<int(kernel::KernelExec *, const TransInfoPair &)>;
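// A TransAxisFunc rewrites a kernel's axis attributes when its runtime format changes.
// kNoNeedTransAxisFunc marks kernels without axis attributes; kNotImplementedTransAxisFunc marks kernels
// (SliceFusion, StridedSlice) whose axis rewrite is not implemented yet, so the fusion attempt fails.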
TransAxisFunc kNoNeedTransAxisFunc = [](kernel::KernelExec *, const TransInfoPair &) { return RET_OK; };
TransAxisFunc kNotImplementedTransAxisFunc = [](kernel::KernelExec *, const TransInfoPair &) { return RET_ERROR; };
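// The Handle*Kernel helpers below remap a kernel's axis attributes with TransFormAxis (declared in
// transpose_strategy.h). For example, when a kernel's data layout changes from NHWC to NCHW, an axis
// of 3 (the C axis in NHWC) is presumably remapped to 1 (the C axis in NCHW).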
int HandleArgMinMaxKernel(const kernel::KernelExec *kernel, const TransInfoPair &trans) {
  auto arg_min_max_param = reinterpret_cast<ArgMinMaxParameter *>(kernel->op_parameter());
  CHECK_NULL_RETURN(arg_min_max_param);
  arg_min_max_param->axis_ = TransFormAxis(arg_min_max_param->axis_, trans);
  return RET_OK;
}

int HandleSoftMaxKernel(const kernel::KernelExec *kernel, const TransInfoPair &trans) {
  // nnacl needs the op_parameter transposed, while BaseOperator needs the Primitive transposed.
  auto param = reinterpret_cast<SoftmaxParameter *>(kernel->op_parameter());
  CHECK_NULL_RETURN(param);
  param->axis_ = TransFormAxis(param->axis_, trans);
  return RET_OK;
}

int HandleSplitKernel(const kernel::KernelExec *kernel, const TransInfoPair &trans) {
  auto param = reinterpret_cast<SplitParameter *>(kernel->op_parameter());
  CHECK_NULL_RETURN(param);
  param->split_dim_ = TransFormAxis(param->split_dim_, trans);
  return RET_OK;
}

int HandleSqueezeKernel(const kernel::KernelExec *kernel, const TransInfoPair &trans) {
  auto param = reinterpret_cast<SqueezeParameter *>(kernel->op_parameter());
  CHECK_NULL_RETURN(param);
  for (size_t i = 0; i < param->axis_size_; i++) {
    param->axis_[i] = TransFormAxis(param->axis_[i], trans);
  }
  return RET_OK;
}

int HandleUnSqueezeKernel(const kernel::KernelExec *kernel, const TransInfoPair &trans) {
  auto param = reinterpret_cast<UnSqueezeParameter *>(kernel->op_parameter());
  CHECK_NULL_RETURN(param);
  param->axis_ = TransFormAxis(param->axis_, trans);
  return RET_OK;
}

int HandleStackKernel(const kernel::KernelExec *kernel, const TransInfoPair &trans) {
  auto param = reinterpret_cast<StackParameter *>(kernel->op_parameter());
  CHECK_NULL_RETURN(param);
  param->axis_ = TransFormAxis(param->axis_, trans);
  return RET_OK;
}

int HandleUnStackKernel(const kernel::KernelExec *kernel, const TransInfoPair &trans) {
  auto param = reinterpret_cast<UnstackParameter *>(kernel->op_parameter());
  CHECK_NULL_RETURN(param);
  param->axis_ = TransFormAxis(param->axis_, trans);
  return RET_OK;
}

int HandleConcatKernel(const kernel::KernelExec *kernel, const TransInfoPair &trans) {
  auto concat_param = reinterpret_cast<ConcatParameter *>(kernel->op_parameter());
  CHECK_NULL_RETURN(concat_param);
  concat_param->axis_ = TransFormAxis(concat_param->axis_, trans);
  return RET_OK;
}

namespace {
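// For Crop, axis_ is the first cropped dimension and offset_ holds one offset per dimension starting
// at axis_, so a format change may shift both fields. The assumed layout: NCHW axis_ = 0 with offsets
// {n, c, h, w} becomes NHWC axis_ = 0 with offsets {n, h, w, c}.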
int Handle0AxisCrop(const TransInfoPair &trans, CropParameter *crop_param) {
  if (IsSameTranspose(trans, kNCHW2NHWCTrans)) {
    // Copy the offsets before rewriting offset_ in place; reading offset_ through a decayed
    // pointer while overwriting it would pick up already-rewritten values.
    const auto offset_n = crop_param->offset_[kNCHW_N];
    const auto offset_h = crop_param->offset_[kNCHW_H];
    const auto offset_w = crop_param->offset_[kNCHW_W];
    const auto offset_c = crop_param->offset_[kNCHW_C];
    crop_param->offset_[kNHWC_N] = offset_n;
    crop_param->offset_[kNHWC_H] = offset_h;
    crop_param->offset_[kNHWC_W] = offset_w;
    crop_param->offset_[kNHWC_C] = offset_c;
    return RET_OK;
  }
  if (IsSameTranspose(trans, kNHWC2NCHWTrans)) {
    const auto offset_n = crop_param->offset_[kNHWC_N];
    const auto offset_h = crop_param->offset_[kNHWC_H];
    const auto offset_w = crop_param->offset_[kNHWC_W];
    const auto offset_c = crop_param->offset_[kNHWC_C];
    crop_param->offset_[kNCHW_N] = offset_n;
    crop_param->offset_[kNCHW_H] = offset_h;
    crop_param->offset_[kNCHW_W] = offset_w;
    crop_param->offset_[kNCHW_C] = offset_c;
    return RET_OK;
  }
  MS_LOG(ERROR) << "Unknown transpose info: from " << trans.src_format_ << " to " << trans.dst_format_;
  return RET_ERROR;
}

int Handle1AxisCrop(const TransInfoPair &trans, CropParameter *crop_param) {
  if (IsSameTranspose(trans, kNCHW2NHWCTrans)) {
    // Copy before overwriting: the assignments below both read and write offset_.
    const auto offset_h = crop_param->offset_[kNCHW_H - 1];
    const auto offset_w = crop_param->offset_[kNCHW_W - 1];
    const auto offset_c = crop_param->offset_[kNCHW_C - 1];
    crop_param->offset_[kNHWC_H - 1] = offset_h;
    crop_param->offset_[kNHWC_W - 1] = offset_w;
    crop_param->offset_[kNHWC_C - 1] = offset_c;
    return RET_OK;
  }
  if (IsSameTranspose(trans, kNHWC2NCHWTrans)) {
    const auto offset_h = crop_param->offset_[kNHWC_H - 1];
    const auto offset_w = crop_param->offset_[kNHWC_W - 1];
    const auto offset_c = crop_param->offset_[kNHWC_C - 1];
    crop_param->offset_[kNCHW_H - 1] = offset_h;
    crop_param->offset_[kNCHW_W - 1] = offset_w;
    crop_param->offset_[kNCHW_C - 1] = offset_c;
    return RET_OK;
  }
  MS_LOG(ERROR) << "Unknown transpose info: from " << trans.src_format_ << " to " << trans.dst_format_;
  return RET_ERROR;
}

int Handle2AxisCrop(const TransInfoPair &trans, CropParameter *crop_param, const lite::Tensor &input_tensor,
                    const lite::Tensor &shape_tensor) {
  if (IsSameTranspose(trans, kNCHW2NHWCTrans)) {
    if (input_tensor.Channel() != shape_tensor.Channel()) {
      return RET_NO_CHANGE;
    }
    crop_param->offset_[4 - crop_param->axis_] = 0;
    crop_param->axis_ = crop_param->axis_ - 1;
    crop_param->offset_size_ = crop_param->offset_size_ + 1;
    return RET_OK;
  }
  auto offset1 = crop_param->offset_[1];
  auto offset0 = crop_param->offset_[0];
  if (IsSameTranspose(trans, kNHWC2NCHWTrans)) {
    if (input_tensor.Height() != shape_tensor.Height()) {
      return RET_NO_CHANGE;
    }
    crop_param->axis_ = 1;
    crop_param->offset_size_ = 3;
    crop_param->offset_[0] = offset1;
    crop_param->offset_[1] = 0;
    crop_param->offset_[2] = offset0;
    return RET_OK;
  }
  MS_LOG(ERROR) << "Unknown transpose info: from " << trans.src_format_ << " to " << trans.dst_format_;
  return RET_ERROR;
}

int Handle3AxisCrop(const TransInfoPair &trans, CropParameter *crop_param, const lite::Tensor &input_tensor,
                    const lite::Tensor &shape_tensor) {
  if (IsSameTranspose(trans, kNCHW2NHWCTrans)) {
    if (input_tensor.Channel() != shape_tensor.Channel()) {
      return RET_NO_CHANGE;
    }
    crop_param->offset_[4 - crop_param->axis_] = 0;
    crop_param->axis_ = crop_param->axis_ - 1;
    crop_param->offset_size_ = crop_param->offset_size_ + 1;
    return RET_OK;
  }
  auto offset0 = crop_param->offset_[0];
  if (IsSameTranspose(trans, kNHWC2NCHWTrans)) {
    if (input_tensor.Height() != shape_tensor.Height() || input_tensor.Width() != shape_tensor.Width()) {
      return RET_NO_CHANGE;
    }
    crop_param->axis_ = 1;
    crop_param->offset_size_ = 3;
    crop_param->offset_[0] = offset0;
    crop_param->offset_[1] = 0;
    crop_param->offset_[2] = 0;
    return RET_OK;
  }
  MS_LOG(ERROR) << "Unknown transpose info: from " << trans.src_format_ << " to " << trans.dst_format_;
  return RET_ERROR;
}
}  // namespace

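// Rewrites the Crop parameter for a format change. Only static 4D inputs are handled; dynamic or
// non-4D shapes return RET_NO_CHANGE so the fusion is skipped instead of failing with an error.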
int HandleCropKernel(const kernel::KernelExec *kernel, const TransInfoPair &trans) {
  auto crop_param = reinterpret_cast<CropParameter *>(kernel->op_parameter());
  CHECK_NULL_RETURN(crop_param);
  auto inputs = kernel->in_tensors();
  for (const auto &input : inputs) {
    auto shape = input->shape();
    if (shape.size() != DIMENSION_4D) {
      return RET_NO_CHANGE;
    }
    if (std::any_of(shape.begin(), shape.end(), [](const int &dim) { return dim < 0; })) {
      return RET_NO_CHANGE;
    }
  }
  if (crop_param->axis_ == 0) {
    return Handle0AxisCrop(trans, crop_param);
  }
  if (crop_param->axis_ == 1) {
    return Handle1AxisCrop(trans, crop_param);
  }
  if (crop_param->axis_ == 2) {
    return Handle2AxisCrop(trans, crop_param, *inputs[0], *inputs[1]);
  }
  if (crop_param->axis_ == 3) {
    return Handle3AxisCrop(trans, crop_param, *inputs[0], *inputs[1]);
  }
  MS_LOG(ERROR) << "Crop axis out of range for a 4D input, axis: " << crop_param->axis_;
  return RET_ERROR;
}

// The mapped function determines how the kernel's axis attribute is rewritten, if it has one.
// A kernel with an axis attribute can be processed only for NHWC2NCHW or NCHW2NHWC transposes.
static const std::unordered_map<schema::PrimitiveType, TransAxisFunc> kTransAxisFuncs = {
  {schema::PrimitiveType_Abs, kNoNeedTransAxisFunc},
  {schema::PrimitiveType_Activation, kNoNeedTransAxisFunc},
  {schema::PrimitiveType_AddFusion, kNoNeedTransAxisFunc},
  {schema::PrimitiveType_AddN, kNoNeedTransAxisFunc},
  {schema::PrimitiveType_ArgMaxFusion, HandleArgMinMaxKernel},
  {schema::PrimitiveType_ArgMinFusion, HandleArgMinMaxKernel},
  {schema::PrimitiveType_Cast, kNoNeedTransAxisFunc},
  {schema::PrimitiveType_Ceil, kNoNeedTransAxisFunc},
  {schema::PrimitiveType_Clip, kNoNeedTransAxisFunc},
  {schema::PrimitiveType_Concat, HandleConcatKernel},
  {schema::PrimitiveType_Cos, kNoNeedTransAxisFunc},
  {schema::PrimitiveType_Crop, HandleCropKernel},
  {schema::PrimitiveType_DivFusion, kNoNeedTransAxisFunc},
  {schema::PrimitiveType_Elu, kNoNeedTransAxisFunc},
  {schema::PrimitiveType_Eltwise, kNoNeedTransAxisFunc},
  {schema::PrimitiveType_Equal, kNoNeedTransAxisFunc},
  {schema::PrimitiveType_ExpFusion, kNoNeedTransAxisFunc},
  {schema::PrimitiveType_Floor, kNoNeedTransAxisFunc},
  {schema::PrimitiveType_FloorDiv, kNoNeedTransAxisFunc},
  {schema::PrimitiveType_FloorMod, kNoNeedTransAxisFunc},
  {schema::PrimitiveType_Greater, kNoNeedTransAxisFunc},
  {schema::PrimitiveType_GreaterEqual, kNoNeedTransAxisFunc},
  {schema::PrimitiveType_Less, kNoNeedTransAxisFunc},
  {schema::PrimitiveType_LessEqual, kNoNeedTransAxisFunc},
  {schema::PrimitiveType_Log, kNoNeedTransAxisFunc},
  {schema::PrimitiveType_LogicalAnd, kNoNeedTransAxisFunc},
  {schema::PrimitiveType_LogicalNot, kNoNeedTransAxisFunc},
  {schema::PrimitiveType_LogicalOr, kNoNeedTransAxisFunc},
  {schema::PrimitiveType_Maximum, kNoNeedTransAxisFunc},
  {schema::PrimitiveType_Minimum, kNoNeedTransAxisFunc},
  {schema::PrimitiveType_Mod, kNoNeedTransAxisFunc},
  {schema::PrimitiveType_MulFusion, kNoNeedTransAxisFunc},
  {schema::PrimitiveType_Neg, kNoNeedTransAxisFunc},
  {schema::PrimitiveType_NotEqual, kNoNeedTransAxisFunc},
  {schema::PrimitiveType_PowFusion, kNoNeedTransAxisFunc},
  {schema::PrimitiveType_QuantDTypeCast, kNoNeedTransAxisFunc},
  {schema::PrimitiveType_RealDiv, kNoNeedTransAxisFunc},
  {schema::PrimitiveType_Round, kNoNeedTransAxisFunc},
  {schema::PrimitiveType_Rsqrt, kNoNeedTransAxisFunc},
  {schema::PrimitiveType_Sin, kNoNeedTransAxisFunc},
  {schema::PrimitiveType_SliceFusion, kNotImplementedTransAxisFunc},
  {schema::PrimitiveType_Softmax, HandleSoftMaxKernel},
  {schema::PrimitiveType_Split, HandleSplitKernel},
  {schema::PrimitiveType_Sqrt, kNoNeedTransAxisFunc},
  {schema::PrimitiveType_Squeeze, HandleSqueezeKernel},
  {schema::PrimitiveType_Square, kNoNeedTransAxisFunc},
  {schema::PrimitiveType_SquaredDifference, kNoNeedTransAxisFunc},
  {schema::PrimitiveType_Stack, HandleStackKernel},
  {schema::PrimitiveType_StridedSlice, kNotImplementedTransAxisFunc},
  {schema::PrimitiveType_SubFusion, kNoNeedTransAxisFunc},
  {schema::PrimitiveType_Unsqueeze, HandleUnSqueezeKernel},
  {schema::PrimitiveType_Unstack, HandleUnStackKernel},
  {schema::PrimitiveType_LogSoftmax, HandleSoftMaxKernel},
  {schema::PrimitiveType_Erf, kNoNeedTransAxisFunc},
};
}  // namespace

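// Pre-check for fusing the transposes that surround `kernel` into the kernel itself: the kernel type
// must be registered in kTransAxisFuncs, the input-side and output-side transposes must be opposite
// permutations, and the fusion must strictly reduce the number of transpose ops. When one side has no
// transpose, its TransInfoPair is inferred as the inverse of the other side.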
bool TransposeStrategy::CrossKernelFusionPreCheck(const kernel::KernelExec *kernel, TransInfoPair *pre_trans,
                                                  TransInfoPair *post_trans) {
  if (kTransAxisFuncs.find(kernel::SchemaType(kernel->type())) == kTransAxisFuncs.end()) {
    return false;
  }
  auto input_count = GetTransCount(kernel->in_kernels(), pre_trans);
  auto output_count = GetTransCount(kernel->out_kernels(), post_trans);
  if (IsSameTranspose(*pre_trans, *post_trans)) {
    return false;
  }
  if (!IsOppositiveTranspose(*pre_trans, *post_trans)) {
    return false;
  }
  auto in_and_out_size = kernel->in_tensors().size() + kernel->out_kernels().size();
  if ((input_count + output_count) <= in_and_out_size / C2NUM) {
    MS_LOG(DEBUG) << "The fusion can't decrease the number of transpose ops.";
    return false;
  }
  if (IsNoneTranspose(*pre_trans)) {
    pre_trans->src_format_ = post_trans->dst_format_;
    pre_trans->dst_format_ = post_trans->src_format_;
  }
  if (IsNoneTranspose(*post_trans)) {
    post_trans->src_format_ = pre_trans->dst_format_;
    post_trans->dst_format_ = pre_trans->src_format_;
  }

  if (((!IsSameTranspose(*post_trans, kNCHW2NHWCTrans)) && (!IsSameTranspose(*post_trans, kNHWC2NCHWTrans))) &&
      kTransAxisFuncs.at(kernel::SchemaType(kernel->type()))) {
    return false;
  }
  if (!CheckInTensorsShape(kernel, (Format)(post_trans->dst_format_))) {
    return false;
  }
  return true;
}

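// Rewrites the axis attributes of `kernel` so that it can run in the new format described by `trans`.
// Fails if no axis-rewrite function is registered for the kernel type.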
int TransposeStrategy::TryTransKernelAxis(kernel::KernelExec *kernel, const TransInfoPair &trans) {
  auto trans_axis_func = kTransAxisFuncs.find(kernel::SchemaType(kernel->type()));
  if (trans_axis_func == kTransAxisFuncs.end() || trans_axis_func->second == nullptr) {
    MS_LOG(ERROR) << "Can't find the axis change function for " << kernel->name();
    return RET_ERROR;
  }
  return trans_axis_func->second(kernel, trans);
}
}  // namespace mindspore::lite::pass